天天看点

Pytorch学习——GAN——MINST

对于GAN的原理,我这里就不多讲了,网上很多。这里主要讲代码,以及调试的踩得坑。

本文参考:

https://blog.csdn.net/qxqsunshine/article/details/84105948

首先导入相关的包。

import torch
import torchvision
import torch.utils.data
import torch.nn
import torch.autograd.variable
from torch.autograd import Variable
from torchvision.utils import save_image
           

数据的处理

transform=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
           

在网上其他代码中,通常会加入如下的代码:

但是会出现通道不匹配的错误提示:

RuntimeError: output with shape [1, 28, 28] doesn’t match the broadcast shape [3, 28, 28]。

原因是因为MINST数据集是单通道的图像(灰度图),但是Normalize均值方差都是三个。解决上述错误,一般会加入如下代码:

这样问题可以暂时解决,但后面又会出现元素数量不对的情况:

RuntimeError: size mismatch, m1: [100 x 2352], m2: [784 x 256] at C:/w/1/s/windows/pytorch/aten/src\THC/generic/THCTensorMathBlas.cu:268

很烦。所以索性不用Normalize,这样所有的问题完美解决。。。

加载数据集

#数据集
test_data=torchvision.datasets.MNIST(
    root='./data/',#路径
    transform=transform,#数据处理
    train=False,#使用测试集,这个看心情
    download=True#下载
)
           

将数据放进加载器,作用就是包装一下。

以下句子来自:

https://www.cnblogs.com/demo-deng/p/10623334.html

数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。

#数据加载器, DataLoader就是用来包装所使用的数据,每次抛出一批数据
test_data_load=torch.utils.data.DataLoader(
    dataset=test_data,
    shuffle=True,#每次打乱顺序
    batch_size=100#批大小,这里根据数据的样本数量而定,最好是能整除
)
           

向量转图片。

其中view()函数,我理解的还不够透彻。。。

#将1*784向量,转换成28*28图片
def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out
           

生成器

采用线性网络,ReLU()作为激活函数,最后一程=层使用Tanh()。至于为什么这么写,只是说实验结果比较好,我没有仔细研究这个。

#生成器
class Generater(torch.nn.Module):
    def __init__(self):
        super(Generater, self).__init__();
        self.G_lay1=torch.nn.Sequential(
            torch.nn.Linear(100,128),
            torch.nn.ReLU(),
            torch.nn.Linear(128,256),
            torch.nn.ReLU(),
            torch.nn.Linear(256,784),
            torch.nn.Tanh()
        )
    def forward(self, x):
        return self.G_lay1(x)
           

判别器

使用LeakyReLU作为激活函数,最后一层Sigmoid们可以理解为二分类器。另外判别器和生成器最好是对称的。

#判别器
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.D_lay2=torch.nn.Sequential(
            torch.nn.Linear(784,256),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Linear(256,128),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Linear(128,1),
            torch.nn.Sigmoid()
        )
    def forward(self, x):
        return self.D_lay2(x)

           

训练前的准备

#实例化
g_net=Generater().cuda()
d_net=Discriminator().cuda()

#损失函数,优化器
loss_fun=torch.nn.BCELoss()
g_optimizer=torch.optim.Adam(g_net.parameters(),lr=0.0002,betas=(0.5,0.999))
d_optimizer=torch.optim.Adam(d_net.parameters(),lr=0.0002,betas=(0.5,0.999))
epoch_n=20

           

这里采用beta1为0.5,在多数的时候一般会使用0.9,这个还是要看具体情况。

我这里只train了20次,电脑太渣。

训练

for epoch in range(epoch_n):
    for i,(img,_) in enumerate(test_data_load):
    
        img_num=img.size(0)
        #训练D——————————————————————————————————————————————————
        #真图——真标签
        img=img.view(img_num,-1)
        real_img=Variable(img).cuda()
        real_output=d_net(real_img)
        real_lab= Variable(torch.ones(img_num)).cuda()
        real_loss=loss_fun(real_output,real_lab)
        
        #假图——假标签
        noise=Variable(torch.randn(img_num, 100)).cuda()
        fake_img=g_net(noise)
        fake_lab=Variable(torch.zeros(img_num)).cuda()
        fake_output=d_net(fake_img)
        fake_loss=loss_fun(fake_output,fake_lab)

        d_loss=real_loss+fake_loss
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #训练G————————————————————————————————————————————
        #假图——真标签
        g_noise = Variable(torch.randn(img_num, 100)).cuda()
        g_img = g_net(noise)
        g_lab=Variable(torch.ones(img_num)).cuda()
        g_output=d_net(g_img)
        g_loss=loss_fun(g_output,g_lab)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'real_loss: {:.6f}, fake_loss: {:.6f}'.format(
                epoch, epoch_n, d_loss.data.item(), g_loss.data.item(),
                real_loss.data.mean(), fake_loss.data.mean()))

    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './img/real_images.png')

    #保存生成的图片
    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images,'./img/fake_images-{}.png'.format(epoch + 1),nrow=10)

#保存模型
torch.save(g_net.state_dict(), './generator.pth')
torch.save(d_net.state_dict(), './discriminator.pth')

           

其中save_image中nrow,表示将一个批次的100张图,按照每行10个排列。

整体代码:

import torch
import torchvision
import torch.utils.data
import torch.nn
import torch.autograd.variable
from torch.autograd import Variable
from torchvision.utils import save_image

#图像读入与处理
transform=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),

])

#数据集
test_data=torchvision.datasets.MNIST(
    root='./data/',#路径
    transform=transform,#数据处理
    train=False,#使用测试集,这个看心情
    download=True#下载
)

#数据加载器, DataLoader就是用来包装所使用的数据,每次抛出一批数据
test_data_load=torch.utils.data.DataLoader(
    dataset=test_data,
    shuffle=True,#每次打乱顺序
    batch_size=100#批大小,这里根据数据的样本数量而定,最好是能整除
)

#将1*784向量,转换成28*28图片
def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out

#生成器
class Generater(torch.nn.Module):
    def __init__(self):
        super(Generater, self).__init__();
        self.G_lay1=torch.nn.Sequential(
            torch.nn.Linear(100,128),
            torch.nn.ReLU(),
            torch.nn.Linear(128,256),
            torch.nn.ReLU(),
            torch.nn.Linear(256,784),
            torch.nn.Tanh()
        )
    def forward(self, x):
        return self.G_lay1(x)

#判别器
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.D_lay2=torch.nn.Sequential(
            torch.nn.Linear(784,256),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Linear(256,128),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Linear(128,1),
            torch.nn.Sigmoid()
        )
    def forward(self, x):
        return self.D_lay2(x)

#实例化
g_net=Generater().cuda()
d_net=Discriminator().cuda()

#损失函数,优化器
loss_fun=torch.nn.BCELoss()
g_optimizer=torch.optim.Adam(g_net.parameters(),lr=0.0002,betas=(0.5,0.999))
d_optimizer=torch.optim.Adam(d_net.parameters(),lr=0.0002,betas=(0.5,0.999))
epoch_n=20

for epoch in range(epoch_n):
    for i,(img,_) in enumerate(test_data_load):

        img_num=img.size(0)
        #训练D——————————————————————————————————————————————————
        #真图——真标签
        img=img.view(img_num,-1)
        real_img=Variable(img).cuda()
        real_output=d_net(real_img)
        real_lab= Variable(torch.ones(img_num)).cuda()
        real_loss=loss_fun(real_output,real_lab)

        #假图——假标签
        noise=Variable(torch.randn(img_num, 100)).cuda()
        fake_img=g_net(noise)
        fake_lab=Variable(torch.zeros(img_num)).cuda()
        fake_output=d_net(fake_img)
        fake_loss=loss_fun(fake_output,fake_lab)

        d_loss=real_loss+fake_loss
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #训练G————————————————————————————————————————————
        #假图——真标签
        g_noise = Variable(torch.randn(img_num, 100)).cuda()
        g_img = g_net(noise)
        g_lab=Variable(torch.ones(img_num)).cuda()
        g_output=d_net(g_img)
        g_loss=loss_fun(g_output,g_lab)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'real_loss: {:.6f}, fake_loss: {:.6f}'.format(
                epoch, epoch_n, d_loss.data.item(), g_loss.data.item(),
                real_loss.data.mean(), fake_loss.data.mean()))

    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './img/real_images.png')

    #保存生成的图片
    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images,'./img/fake_images-{}.png'.format(epoch + 1),nrow=10)

torch.save(g_net.state_dict(), './generator.pth')
torch.save(d_net.state_dict(), './discriminator.pth')

           

以上代码,并非全部原创,只是站在巨人肩膀上,如果侵权,请评论或私信联系 ,如果有错误,请评论或私信,谢谢。

继续阅读