天天看点

pytorch实现mnist上的简单的GAN

看李宏毅老师的视频,实现下简单的mnist上面的GAN。李宏毅老师课程

使用的是非常简单的架构,一个生成器,一个判别器。

每个batch

  • 生成器生成图片,使得图片可以骗过判别器
  • 判别器对真图片和生成图片做二分类
    pytorch实现mnist上的简单的GAN
    代码如下:参考github
import os
import numpy as np

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch
import time

epochs = 200
batch_size = 512
lr = 0.0002
b1 = 0.5
b2 = 0.999
latent_dim = 2 # 生成器输入维度
img_size = 28
channels = 1
sample_interval = 500 # 从中间过程中采样

img_shape = (channels,img_size,img_size) # 图像大小

# 是否使用GPU
cuda = True if torch.cuda.is_available() else False

# 定义生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 32, normalize=False),
            *block(32, 64),
            *block(64, 128),
            *block(128, 256),
            nn.Linear(256, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

        # shape = (1,2,3)
        # out = np.prod(shape)
        # 6
        # 将输出的(batch_size, np.prod(shape))向量
        # 转化为(batch_size, shape) 的图片
        # b = torch.tensor([[1,2,3,4,7,8],[4,5,6,7,9,10]])
        # c = b.view(b.size(0),*shape)
        #tensor([[[[ 1,  2,  3],
        #   [ 4,  7,  8]]],
        # [[[ 4,  5,  6],
        #   [ 7,  9, 10]]]])
        # view 从最后一个参数开始,所以b.view(2,1,2,3) 得到每行3个像素([ 4,  5,  6]),每列
        # 2个像素([[ 4,  5,  6],[ 7,  9, 10]]),一个通道([[[ 4,  5,  6],[ 7,  9, 10]]])
        # 的batch_size张图片

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

# 损失函数
adversarial_loss = torch.nn.BCELoss()

# 创建两个模型
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# 加载mnist数据
# 创建文件夹,linux用os.mkdir("../data/mnist",exist_ok=True)
# exist_ok = True 表示目录存在时不报错
os.makedirs("..\\data\\mnist",exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "..\\data\\mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5])]
        )
    ),
    batch_size=batch_size,
    shuffle=True
)

# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(),lr=lr,betas=(b1,b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr=lr,betas=(b1,b2))

# 数据类型
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

time_start = time.time()
# 训练网络
if __name__ == '__main__':
    for epoch in range(epochs):
        for i,(imgs,label) in enumerate(dataloader):
            # 创建标签,二分类:采样图像全1,生成图像全0
            valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

            real_imgs = Variable(imgs.type(Tensor))

            # 训练生成器
            optimizer_G.zero_grad()
            # 生成随机噪声
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))
            # 生成图像
            gen_imgs = generator(z)
            # 使得生成的图像骗过判别器-> 二分类损失
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

            # 训练判别器
            # 判别器应该将真的采样图像分类为真,生成的图像分类为假
            # 注意要使用detach()函数将生成图像从计算图中分离
            # 这样判别器的损失不会反向传播到生成器哪里
            optimizer_D.zero_grad()
            real_loss = adversarial_loss(discriminator(real_imgs),valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()),fake)
            d_loss= (real_loss+fake_loss) / 2.
            d_loss.backward()
            optimizer_D.step()

            if i % 100 == 1:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                    % (epoch, epochs, i, len(dataloader), d_loss.item(), g_loss.item())
                )

            batches_done = epoch * len(dataloader) + i
            if batches_done % sample_interval == 0:
                save_image(gen_imgs.data[:25], "images%d.png" % batches_done, nrow=5, normalize=True)
        end_time = time.time()
        print('cost time:', end_time - time_start)
    torch.save(generator,'gan_g.pkl')

           

另外,代码中还有值得说明的是

LeakyRelu

。传统的

Relu

对所有的负数输出都是0,但是负数部分可能会有用,

LeakyRelu

为负数部分添加微小的斜率,使得负半轴也可以被区分。其数学公式为:

LeakyReLU ( x ) = max ⁡ ( 0 , x ) + negative_slope ∗ min ⁡ ( 0 , x ) \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x) LeakyReLU(x)=max(0,x)+negative_slope∗min(0,x)

import torch
import matplotlib.pyplot as plt
input = torch.linspace(-2,2,100)
leakyRelu = torch.nn.LeakyReLU(0.2)
output = leakyRelu(input)
plt.plot(input,output)
           
pytorch实现mnist上的简单的GAN

运行上一段GAN的代码会得到训练的中间过程中网络生成的图像,最后一张如下,还比较清楚。

pytorch实现mnist上的简单的GAN

我们使用下面的一段代码将所有的图片制作成gif:

import imageio
import os
import re

def create_gif(source,name,duration):
    frames = []
    for img in source:
        # 利用正则表达式筛选出图片
        if re.search("images[0-9]*.png",img):
            frames.append(imageio.imread(img))
    imageio.mimsave(name,frames,'GIF',duration=duration)
    print("%d IMG 2 GIF: %s DONE!" % (len(frames),name))


def main(img_path,gif_name):
    path = os.chdir(img_path)
    pic_list = os.listdir()
    duration_time = 0.2
    create_gif(pic_list,gif_name,duration_time)

main(".\\","gan.gif")

           

将训练过程中的图片制作成gif如下:

pytorch实现mnist上的简单的GAN

另外,你可能注意到了我的Generator的输入是2维的,比较小。这是为了我可以做出一张2d的图。在2维平面中展示平面中的点到生成图像间的对应关系。

model = Generator().cuda()
model = torch.load('gan_g.pkl')
model.eval()
ones = np.ones(latent_dim)
input = []
for r in torch.linspace(-2.5,2.5,20):
    for c in torch.linspace(-2.5,2.5,20):
        input.append([r,c])
input = Variable(torch.tensor(input,dtype=torch.float32).cuda())
gen_imgs = model(input)
save_image(gen_imgs.data, "res.png", nrow=20, normalize=True)
           
pytorch实现mnist上的简单的GAN

最后,mnist的数据加载时normalize使用的是[0.5,0.5] ,不是[0.1307, 0.3081],因为参考代码里面用的是0.5。我试了下第二种,结果会有很多噪点。

继续阅读