天天看點

Pytorch學習——GAN——MINST

對于GAN的原理,我這裡就不多講了,網上很多。這裡主要講代碼,以及調試的踩得坑。

本文參考:

https://blog.csdn.net/qxqsunshine/article/details/84105948

首先導入相關的包。

import torch
import torchvision
import torch.utils.data
import torch.nn
import torch.autograd.variable
from torch.autograd import Variable
from torchvision.utils import save_image
           

資料的處理

transform=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
           

在網上其他代碼中,通常會加入如下的代碼:

但是會出現通道不比對的錯誤提示:

RuntimeError: output with shape [1, 28, 28] doesn’t match the broadcast shape [3, 28, 28]。

原因是因為MINST資料集是單通道的圖像(灰階圖),但是Normalize均值方差都是三個。解決上述錯誤,一般會加入如下代碼:

這樣問題可以暫時解決,但後面又會出現元素數量不對的情況:

RuntimeError: size mismatch, m1: [100 x 2352], m2: [784 x 256] at C:/w/1/s/windows/pytorch/aten/src\THC/generic/THCTensorMathBlas.cu:268

很煩。是以索性不用Normalize,這樣所有的問題完美解決。。。

加載資料集

#資料集
test_data=torchvision.datasets.MNIST(
    root='./data/',#路徑
    transform=transform,#資料處理
    train=False,#使用測試集,這個看心情
    download=True#下載下傳
)
           

将資料放進加載器,作用就是包裝一下。

以下句子來自:

https://www.cnblogs.com/demo-deng/p/10623334.html

資料加載器,結合了資料集和取樣器,并且可以提供多個線程處理資料集。在訓練模型時使用到此函數,用來把訓練資料分成多個小組,此函數每次抛出一組資料。直至把所有的資料都抛出。就是做一個資料的初始化。

#資料加載器, DataLoader就是用來包裝所使用的資料,每次抛出一批資料
test_data_load=torch.utils.data.DataLoader(
    dataset=test_data,
    shuffle=True,#每次打亂順序
    batch_size=100#批大小,這裡根據資料的樣本數量而定,最好是能整除
)
           

向量轉圖檔。

其中view()函數,我了解的還不夠透徹。。。

#将1*784向量,轉換成28*28圖檔
def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out
           

生成器

采用線性網絡,ReLU()作為激活函數,最後一程=層使用Tanh()。至于為什麼這麼寫,隻是說實驗結果比較好,我沒有仔細研究這個。

#生成器
class Generater(torch.nn.Module):
    def __init__(self):
        super(Generater, self).__init__();
        self.G_lay1=torch.nn.Sequential(
            torch.nn.Linear(100,128),
            torch.nn.ReLU(),
            torch.nn.Linear(128,256),
            torch.nn.ReLU(),
            torch.nn.Linear(256,784),
            torch.nn.Tanh()
        )
    def forward(self, x):
        return self.G_lay1(x)
           

判别器

使用LeakyReLU作為激活函數,最後一層Sigmoid們可以了解為二分類器。另外判别器和生成器最好是對稱的。

#判别器
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.D_lay2=torch.nn.Sequential(
            torch.nn.Linear(784,256),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Linear(256,128),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Linear(128,1),
            torch.nn.Sigmoid()
        )
    def forward(self, x):
        return self.D_lay2(x)

           

訓練前的準備

#執行個體化
g_net=Generater().cuda()
d_net=Discriminator().cuda()

#損失函數,優化器
loss_fun=torch.nn.BCELoss()
g_optimizer=torch.optim.Adam(g_net.parameters(),lr=0.0002,betas=(0.5,0.999))
d_optimizer=torch.optim.Adam(d_net.parameters(),lr=0.0002,betas=(0.5,0.999))
epoch_n=20

           

這裡采用beta1為0.5,在多數的時候一般會使用0.9,這個還是要看具體情況。

我這裡隻train了20次,電腦太渣。

訓練

for epoch in range(epoch_n):
    for i,(img,_) in enumerate(test_data_load):
    
        img_num=img.size(0)
        #訓練D——————————————————————————————————————————————————
        #真圖——真标簽
        img=img.view(img_num,-1)
        real_img=Variable(img).cuda()
        real_output=d_net(real_img)
        real_lab= Variable(torch.ones(img_num)).cuda()
        real_loss=loss_fun(real_output,real_lab)
        
        #假圖——假标簽
        noise=Variable(torch.randn(img_num, 100)).cuda()
        fake_img=g_net(noise)
        fake_lab=Variable(torch.zeros(img_num)).cuda()
        fake_output=d_net(fake_img)
        fake_loss=loss_fun(fake_output,fake_lab)

        d_loss=real_loss+fake_loss
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #訓練G————————————————————————————————————————————
        #假圖——真标簽
        g_noise = Variable(torch.randn(img_num, 100)).cuda()
        g_img = g_net(noise)
        g_lab=Variable(torch.ones(img_num)).cuda()
        g_output=d_net(g_img)
        g_loss=loss_fun(g_output,g_lab)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'real_loss: {:.6f}, fake_loss: {:.6f}'.format(
                epoch, epoch_n, d_loss.data.item(), g_loss.data.item(),
                real_loss.data.mean(), fake_loss.data.mean()))

    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './img/real_images.png')

    #儲存生成的圖檔
    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images,'./img/fake_images-{}.png'.format(epoch + 1),nrow=10)

#儲存模型
torch.save(g_net.state_dict(), './generator.pth')
torch.save(d_net.state_dict(), './discriminator.pth')

           

其中save_image中nrow,表示将一個批次的100張圖,按照每行10個排列。

整體代碼:

import torch
import torchvision
import torch.utils.data
import torch.nn
import torch.autograd.variable
from torch.autograd import Variable
from torchvision.utils import save_image

#圖像讀入與處理
transform=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),

])

#資料集
test_data=torchvision.datasets.MNIST(
    root='./data/',#路徑
    transform=transform,#資料處理
    train=False,#使用測試集,這個看心情
    download=True#下載下傳
)

#資料加載器, DataLoader就是用來包裝所使用的資料,每次抛出一批資料
test_data_load=torch.utils.data.DataLoader(
    dataset=test_data,
    shuffle=True,#每次打亂順序
    batch_size=100#批大小,這裡根據資料的樣本數量而定,最好是能整除
)

#将1*784向量,轉換成28*28圖檔
def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out

#生成器
class Generater(torch.nn.Module):
    def __init__(self):
        super(Generater, self).__init__();
        self.G_lay1=torch.nn.Sequential(
            torch.nn.Linear(100,128),
            torch.nn.ReLU(),
            torch.nn.Linear(128,256),
            torch.nn.ReLU(),
            torch.nn.Linear(256,784),
            torch.nn.Tanh()
        )
    def forward(self, x):
        return self.G_lay1(x)

#判别器
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.D_lay2=torch.nn.Sequential(
            torch.nn.Linear(784,256),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Linear(256,128),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Linear(128,1),
            torch.nn.Sigmoid()
        )
    def forward(self, x):
        return self.D_lay2(x)

#執行個體化
g_net=Generater().cuda()
d_net=Discriminator().cuda()

#損失函數,優化器
loss_fun=torch.nn.BCELoss()
g_optimizer=torch.optim.Adam(g_net.parameters(),lr=0.0002,betas=(0.5,0.999))
d_optimizer=torch.optim.Adam(d_net.parameters(),lr=0.0002,betas=(0.5,0.999))
epoch_n=20

for epoch in range(epoch_n):
    for i,(img,_) in enumerate(test_data_load):

        img_num=img.size(0)
        #訓練D——————————————————————————————————————————————————
        #真圖——真标簽
        img=img.view(img_num,-1)
        real_img=Variable(img).cuda()
        real_output=d_net(real_img)
        real_lab= Variable(torch.ones(img_num)).cuda()
        real_loss=loss_fun(real_output,real_lab)

        #假圖——假标簽
        noise=Variable(torch.randn(img_num, 100)).cuda()
        fake_img=g_net(noise)
        fake_lab=Variable(torch.zeros(img_num)).cuda()
        fake_output=d_net(fake_img)
        fake_loss=loss_fun(fake_output,fake_lab)

        d_loss=real_loss+fake_loss
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #訓練G————————————————————————————————————————————
        #假圖——真标簽
        g_noise = Variable(torch.randn(img_num, 100)).cuda()
        g_img = g_net(noise)
        g_lab=Variable(torch.ones(img_num)).cuda()
        g_output=d_net(g_img)
        g_loss=loss_fun(g_output,g_lab)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'real_loss: {:.6f}, fake_loss: {:.6f}'.format(
                epoch, epoch_n, d_loss.data.item(), g_loss.data.item(),
                real_loss.data.mean(), fake_loss.data.mean()))

    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './img/real_images.png')

    #儲存生成的圖檔
    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images,'./img/fake_images-{}.png'.format(epoch + 1),nrow=10)

torch.save(g_net.state_dict(), './generator.pth')
torch.save(d_net.state_dict(), './discriminator.pth')

           

以上代碼,并非全部原創,隻是站在巨人肩膀上,如果侵權,請評論或私信聯系 ,如果有錯誤,請評論或私信,謝謝。

繼續閱讀