
Implementing a simple GAN on MNIST with PyTorch

After watching Hung-yi Lee's (李宏毅) lectures, I implemented a simple GAN on MNIST. (Link: Hung-yi Lee's course)

The setup is very simple: one generator and one discriminator.

Every batch:

  • the generator produces images meant to fool the discriminator
  • the discriminator does binary classification on real versus generated images (the objectives are spelled out just below)
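Written out with the binary cross-entropy loss that the code below uses, one training step amounts to (my own summary of the standard GAN objective, not taken from the course):

\mathcal{L}_D = -\tfrac{1}{2}\,\mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] - \tfrac{1}{2}\,\mathbb{E}_{z \sim \mathcal{N}(0,I)}\big[\log\big(1 - D(G(z))\big)\big]

\mathcal{L}_G = -\,\mathbb{E}_{z \sim \mathcal{N}(0,I)}\big[\log D(G(z))\big]

The discriminator takes one Adam step to minimize \mathcal{L}_D (BCELoss against all-ones labels for real images and all-zeros labels for generated ones), and the generator takes one step to minimize \mathcal{L}_G (BCELoss of the discriminator's output on generated images against all-ones labels).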
The code is below (adapted from a reference implementation on GitHub):
import os
import numpy as np

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch
import time

epochs = 200
batch_size = 512
lr = 0.0002
b1 = 0.5
b2 = 0.999
latent_dim = 2 # dimension of the generator's latent input
img_size = 28
channels = 1
sample_interval = 500 # save sample images every this many batches

img_shape = (channels,img_size,img_size) # image shape

# Use the GPU if one is available
cuda = True if torch.cuda.is_available() else False

# Define the generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 32, normalize=False),
            *block(32, 64),
            *block(64, 128),
            *block(128, 256),
            nn.Linear(256, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

        # shape = (1,2,3)
        # out = np.prod(shape)
        # 6
        # view turns the (batch_size, np.prod(shape)) output vector
        # into images of shape (batch_size, *shape)
        # b = torch.tensor([[1,2,3,4,7,8],[4,5,6,7,9,10]])
        # c = b.view(b.size(0),*shape)
        # tensor([[[[ 1,  2,  3],
        #    [ 4,  7,  8]]],
        #  [[[ 4,  5,  6],
        #    [ 7,  9, 10]]]])
        # view fills dimensions starting from the last argument, so b.view(2,1,2,3) gives
        # rows of 3 pixels ([ 4,  5,  6]), 2 rows per channel ([[ 4,  5,  6],[ 7,  9, 10]]),
        # and one channel ([[[ 4,  5,  6],[ 7,  9, 10]]]) for each of the batch_size images

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

# Loss function
adversarial_loss = torch.nn.BCELoss()

# Instantiate the two models
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Load the MNIST data
# Create the data directory; on Linux use os.makedirs("../data/mnist", exist_ok=True) with forward slashes
# exist_ok=True means no error is raised if the directory already exists
os.makedirs("..\\data\\mnist",exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "..\\data\\mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5])]
        )
    ),
    batch_size=batch_size,
    shuffle=True
)

# Optimizers (Adam for both networks)
optimizer_G = torch.optim.Adam(generator.parameters(),lr=lr,betas=(b1,b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr=lr,betas=(b1,b2))

# Tensor type: CUDA tensors when a GPU is used, CPU tensors otherwise
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

time_start = time.time()
# Train the networks
if __name__ == '__main__':
    for epoch in range(epochs):
        for i,(imgs,label) in enumerate(dataloader):
            # Labels for the binary classification: real images are all 1s, generated images all 0s
            valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

            real_imgs = Variable(imgs.type(Tensor))

            # Train the generator
            optimizer_G.zero_grad()
            # Sample random noise for the generator input
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))
            # Generate images
            gen_imgs = generator(z)
            # The generated images should fool the discriminator -> binary classification loss against the "valid" labels
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

            # Train the discriminator
            # It should classify real sampled images as real and generated images as fake
            # Note: detach() separates the generated images from the computation graph,
            # so the discriminator's loss does not backpropagate into the generator
            optimizer_D.zero_grad()
            real_loss = adversarial_loss(discriminator(real_imgs),valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()),fake)
            d_loss= (real_loss+fake_loss) / 2.
            d_loss.backward()
            optimizer_D.step()

            if i % 100 == 1:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                    % (epoch, epochs, i, len(dataloader), d_loss.item(), g_loss.item())
                )

            batches_done = epoch * len(dataloader) + i
            if batches_done % sample_interval == 0:
                save_image(gen_imgs.data[:25], "images%d.png" % batches_done, nrow=5, normalize=True)
        end_time = time.time()
        print('cost time:', end_time - time_start)
    torch.save(generator,'gan_g.pkl')

           

Another thing in the code worth mentioning is LeakyReLU. A plain ReLU outputs 0 for every negative input, but the negative part can still carry useful information; LeakyReLU adds a small slope on the negative side so that negative inputs remain distinguishable. Its formula is:

\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} \cdot \min(0, x)

import torch
import matplotlib.pyplot as plt

# Plot LeakyReLU with negative_slope = 0.2 over [-2, 2]
input = torch.linspace(-2,2,100)
leakyRelu = torch.nn.LeakyReLU(0.2)
output = leakyRelu(input)
plt.plot(input, output)
plt.show()
[Figure: plot of LeakyReLU with negative_slope = 0.2]

Running the GAN code above produces images generated by the network at intermediate points during training; the last one is shown below, and the digits are fairly clear.

[Figure: MNIST digits generated near the end of training]

The following snippet turns all of the saved images into a GIF:

import imageio
import os
import re

def create_gif(source,name,duration):
    frames = []
    for img in source:
        # Use a regex to pick out the generated sample images
        if re.search(r"images[0-9]+\.png", img):
            frames.append(imageio.imread(img))
    imageio.mimsave(name,frames,'GIF',duration=duration)
    print("%d IMG 2 GIF: %s DONE!" % (len(frames),name))


def main(img_path,gif_name):
    os.chdir(img_path)
    # Sort by the batch number in the filename so the frames appear in training order
    pic_list = sorted(os.listdir(), key=lambda f: int("0" + re.sub(r"[^0-9]", "", f)))
    duration_time = 0.2
    create_gif(pic_list,gif_name,duration_time)

main(".\\","gan.gif")

The frames from the training process, stitched into a GIF:

[Animation: generated samples over the course of training]

Also, you may have noticed that my Generator's input is only 2-dimensional, which is quite small. This is so I can make a 2D picture that shows how points in the 2-dimensional latent plane map to generated images.

# The Generator class definition above must be in scope so torch.load can unpickle the model
model = torch.load('gan_g.pkl')
model.eval()
# Sample a 20x20 grid of points covering [-2.5, 2.5] x [-2.5, 2.5] in the 2D latent space
input = []
for r in torch.linspace(-2.5,2.5,20):
    for c in torch.linspace(-2.5,2.5,20):
        input.append([r,c])
input = Variable(torch.tensor(input,dtype=torch.float32).cuda())
# Generate one image per latent point and tile them into a 20x20 grid
gen_imgs = model(input)
save_image(gen_imgs.data, "res.png", nrow=20, normalize=True)
[Figure: 20x20 grid of digits generated from evenly spaced points in the 2D latent plane]

Finally, when loading MNIST the normalization uses [0.5], [0.5] rather than [0.1307], [0.3081], because that is what the reference code uses. Since the generator's last layer is Tanh, whose output lies in [-1, 1], normalizing with mean 0.5 and std 0.5 puts the real images in the same range. I did try the MNIST statistics, and the results came out with a lot of noise.
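As a quick sanity check (a minimal snippet, not part of the original script): ToTensor() produces pixels in [0, 1], and Normalize([0.5], [0.5]) maps that range onto [-1, 1], exactly the range of the generator's Tanh output.

import torch
import torchvision.transforms as transforms

# Extreme pixel values after ToTensor(): 0.0 (black) and 1.0 (white), shaped (C, H, W)
pixels = torch.tensor([0.0, 1.0]).view(1, 1, 2)
norm = transforms.Normalize([0.5], [0.5])
print(norm(pixels))  # tensor([[[-1., 1.]]]) -> same [-1, 1] range as Tanh

With [0.1307], [0.3081] the real images end up roughly in [-0.42, 2.82], which no longer matches the Tanh range; that mismatch is a plausible reason for the noisy results.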
