天天看點

深度學習架構PyTorch入門與實踐:第七章 AI插畫師:生成對抗網絡

生成對抗網絡(Generative Adversarial Net,GAN)是近年來深度學習中一個十分熱門的方向,卷積網絡之父、深度學習元老級人物LeCun Yan就曾說過“GAN is the most interesting idea in the last 10 years in machine learning”。尤其是近兩年,GAN的論文呈現井噴的趨勢,GitHub上有人收集了各種各樣的GAN變種、應用、研究論文等,其中有名稱的多達數百篇**[the-gan-zoo]**。作者還統計了GAN論文發表數目随時間變化的趨勢,如下圖所示,足見GAN的火爆程度。本節将簡要介紹GAN的基本原理,并帶領讀者實作一個簡單的生成對抗網絡,用以生成動漫人物的頭像。

深度學習架構PyTorch入門與實踐:第七章 AI插畫師:生成對抗網絡

7.1 GAN的原理簡介

GAN的開山之作是被稱為“GAN之父”的Ian Goodfellow發表于2014年的經典論文《Generative Adversarial Networks》,在這篇論文中他提出了生成對抗網絡,并設計了第一個GAN實驗——手寫數字生成。

GAN的産生來自于一個靈機一動的想法:

“What I cannot create, I do not understand.”(那些我所不能創造的,我也沒有真正了解它。)—— Richard Feynman

類似地。如果深度學習不能創造圖檔,那麼它也沒有真正地了解圖檔。當時深度學習已經開始在各類計算機視覺領域中攻城略地,在幾乎所有任務中都取得了突破。但是人們一直對神經網絡的黑盒模型表示質疑,于是越來越多的人從可視化的角度探索卷積網絡所學習的特征和特征間的組合,而GAN則從生成學習角度展示了神經網絡的強大能力。GAN解決了非監督學習中的著名問題:給定一批樣本,訓練一個系統能夠生成類似的樣本。

生成對抗網絡的網絡結構如下圖所示,主要包含以下兩個子網絡:

  • 生成器(generator):輸入一個随機噪聲,生成一張圖檔。
  • 判别器(discriminator):判斷輸入的圖檔是真圖檔還是假圖檔。
深度學習架構PyTorch入門與實踐:第七章 AI插畫師:生成對抗網絡

訓練判别器時,需要利用生成器生成的假圖檔和來自真實世界的真圖檔;訓練生成器時,隻用噪聲生成假圖檔。判别器用來評估生成的假圖檔的品質,促使生成器相應地調整參數。

生成器的目标是盡可能地生成以假亂真的圖檔,讓判别器以為這是真的圖檔;判别器的目标是将生成器生成的圖檔和真實世界的圖檔區分開。可以看出這二者的目标相反,在訓練過程中互相對抗,這也是它被稱為生成對抗網絡的原因。

上面的描述可能有點抽象,讓我們用收藏齊白石作品(齊白石作品如下圖所示)的書畫收藏家和假畫販子的例子來說明。假畫販子相當于是生成器,他們希望能夠模仿大師真迹僞造出以假亂真的假畫,騙過收藏家,進而賣出高價;書畫收藏家則希望将赝品和真迹區分開,讓真迹流傳于世。齊白石畫蝦可以說是畫壇一絕,曆來為世人所追捧。

深度學習架構PyTorch入門與實踐:第七章 AI插畫師:生成對抗網絡

在這個例子中,一開始假畫販子和書畫收藏家都是新手,他們對真迹和赝品的概念都很模糊。假畫販子仿造出來的假畫幾乎都是随機塗鴉,而書畫收藏家的鑒定能力很差,有不少赝品被他當成真迹,也有許多真迹被當成赝品。

首先,書畫收藏家收集了一大堆市面上的赝品和齊白石大師的真迹,仔細研究對比,初步學習了畫中蝦的結構,明白畫中的生物形狀彎曲,并且有一對類似鉗子的“螯足”,對于不符合這個條件的假畫全部過濾掉。當收藏家用這個标準到市場上進行鑒定,假畫基本無法騙過收藏家,假畫販子損失慘重。但是假畫販子自己仿造的赝品中,還是有一些蒙騙過關,這些蒙騙過關的赝品中都有彎曲的形狀,并且有一對類似鉗子的“螯足”。于是假畫販子開始修改仿造的手法,在仿造的作品中加入彎曲的形狀和一對類似鉗子的“螯足”。除了這些特點,其他地方例如顔色、線條都是随機畫的。假畫販子制造出的第一版赝品如下所示。

深度學習架構PyTorch入門與實踐:第七章 AI插畫師:生成對抗網絡

當假畫販子把這些畫拿到市面上去賣時,很容易就騙過了收藏家,因為畫中有一隻彎曲的生物,生物前面有一對類似鉗子的東西,符合收藏家認定的真迹的标準,是以收藏家就把它當成真迹買回來。随機時間的推移,收藏家買回來越來越多的假畫,損失慘重,于是他又閉門研究赝品和真迹之間的差別,經過反複比較對比,他發現齊白石畫蝦的真迹中除了有彎曲的形狀、蝦的觸須蔓長,通身作半透明狀,并且畫的蝦的細節十分豐富,蝦的每一節之間均呈白色狀。

收藏家學成之後,重新出山,而假畫販子的仿造技法沒有提升,所制造出來的赝品被收藏家輕松識破。于是假畫販子也開始嘗試不同的畫蝦手法,大多都是徒勞無功,不過在衆多嘗試之中,還是有一些赝品騙過了收藏家的眼睛。假畫販子發現這些仿制的赝品觸須蔓長,通身作半透明狀,并且畫的蝦的細節十分豐富,如下所示。于是假畫販子開始大量仿造這種畫,并拿到市面上銷售,許多都成功地騙過了收藏家。

深度學習架構PyTorch入門與實踐:第七章 AI插畫師:生成對抗網絡

收藏家再度損失慘重,被迫關門研究齊白石的真迹和赝品之間的差別,學習齊白石真迹的特點,提升自己的鑒定能力。就這樣,通過收藏家和假畫販子之間的博弈,收藏家從零開始慢慢提升了自己對真迹和赝品的鑒别能力,而假畫販子也不斷地提高自己仿造齊白石真迹的水準。收藏家利用假畫販子提供的赝品,作為和真迹的對比,對齊白石畫蝦真迹有了更好的鑒賞能力;而假畫販子也不斷嘗試,提升仿造水準,提升仿造假畫的品質,即使最後制造出來的仍屬于赝品,但是和真迹相比也很接近了。收藏家和假畫販子二者之間互相博弈對抗,同時又不斷促使着對方學習進步,達到共同提升的目的。

在這個例子中,假畫販子相當于一個生成器,收藏家相當于一個判别器。一開始生成器和判别器的水準都很差,因為二者都是随機初始化的。訓練過程分為兩步交替進行,第一步是訓練判别器(隻修改判别器的參數,固定生成器),目标是把真迹和赝品區分開;第二步是訓練生成器(隻修改生成器的參數,固定判别器),為的是生成的假畫能夠被判别器判别為真迹(被收藏家認為是真迹)。這兩步交替進行,進而生成器和判别器都達到了一個很高的水準。訓練到最後,生成的蝦的圖檔如下所示,和齊白石的真迹幾乎沒有差别。

深度學習架構PyTorch入門與實踐:第七章 AI插畫師:生成對抗網絡

下面我們來思考網絡結構的設計。判别器的目标是判斷輸入的圖檔是真迹還是赝品,是以可以看成是一個二分類網絡,參考第6章中Dogs vs. Cats的實驗,我們可以設計一個簡單的卷積網絡。生成器的目标是從噪聲中生成一張彩色圖檔,這裡我們采用廣泛使用的DCGAN(Deep Convolutional Generative Adversarial Networks)結構,即采用全卷積網絡,其結構如下所示。網絡的輸入是一個100維的噪聲,輸出的是一個3 * 64 * 64的圖檔。這裡的輸入可以看成是一個100 * 1 * 1的圖檔,通過上卷積慢慢增大為4 * 4、8 * 8、16 * 16、32 * 32和64 * 64。上卷積,或稱為轉置卷積,是一種特殊的卷積操作,類似于卷積操作的逆運算。當卷積的stride為2時,輸出相比輸入會下采樣到一半的尺寸;而當上卷積的stride為2時,輸出會上采樣到輸入的兩倍尺寸。這種上采樣的做法可以了解為圖檔的資訊儲存于100個向量之中,神經網絡根據這100個向量描述的資訊,前幾步的上采樣先勾勒出輪廓、色調等基礎資訊,後幾步上采樣慢慢完善細節。網絡越深,細節越詳細。

深度學習架構PyTorch入門與實踐:第七章 AI插畫師:生成對抗網絡

在DCGAN中,判别器的結構和生成器對稱:生成器中采用上采樣的卷積,判别器就采用下采樣的卷積,生成器時根據噪聲輸出一張64 * 64 * 3的圖檔,而判别器則是根據輸入的64 * 64 * 3的圖檔輸出圖檔屬于正負樣本的分數(機率)。

7.2 用GAN生成動漫頭像

本章所有代碼及圖檔資料百度網盤下載下傳,提取碼:b5da。

本節将用GAN實作一個生成動漫人物頭像的例子。在日本的技術部落格網站上有個部落客(估計是一位二次元的愛好者)

@mattya,利用DCGAN從20萬張動漫頭像中學習,最終能夠利用程式自動生成動漫頭像,生成的圖檔效果如下圖所示。源程式是利用Chainer架構實作的,本節我們嘗試利用PyTorch實作。

深度學習架構PyTorch入門與實踐:第七章 AI插畫師:生成對抗網絡

原始的圖檔是從網站中爬取的,并利用OpenCV從中截取頭像,處理起來比較麻煩。這裡我們使用知乎使用者何之源爬取并經過處理的5萬張圖檔。可從本書配套程式的README.MD的百度網盤連結下載下傳所有的圖檔壓縮包,并解壓到指定的檔案夾中。需要注意的是,這裡的圖檔的分辨率是3 * 96 * 96,而不是論文中的3 * 64 * 64,是以需要相應地調整網絡結構,使生成圖像的尺寸為96。

我們先來看本實驗的代碼結構。

checkpoints/    # 無代碼,用來儲存模型
imgs/    # 無代碼,用來儲存生成的圖檔
data/    # 無代碼,用來儲存訓練所需的圖檔
main.py    # 訓練和生成
model.py    # 模型定義
visualize.py    # 可視化工具visdom的封裝
requirements.txt    # 程式中用到的第三方庫
README.MD    # 說明
           

接着來看model.py中是如何定義生成器的。

# coding:utf8
from torch import nn


class NetG(nn.Module):
    """
    生成器定義
    """

    def __init__(self, opt):
        super(NetG, self).__init__()
        ngf = opt.ngf  # 生成器feature map數

        self.main = nn.Sequential(
            # 輸入是一個nz次元的噪聲,我們可以認為它是一個1*1*nz的feature map
            nn.ConvTranspose2d(opt.nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 上一步的輸出形狀:(ngf*8) x 4 x 4

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 上一步的輸出形狀: (ngf*4) x 8 x 8

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 上一步的輸出形狀: (ngf*2) x 16 x 16

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 上一步的輸出形狀:(ngf) x 32 x 32

            nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
            nn.Tanh()  # 輸出範圍 -1~1 故而采用Tanh
            # 輸出形狀:3 x 96 x 96
        )

    def forward(self, input):
        return self.main(input)

           

可以看出生成器的搭建相對比較簡單,直接使用nn.Sequential将上卷積、激活、池化等操作拼接起來即可,這裡需要注意上卷積ConvTranspose2d的使用。當kernel_size為4,stride為2,padding為1時,根據公式 H o u t = ( H i n − 1 ) − 2 ∗ p a d d i n g + k e r n e l _ s i z e H_{out} = ( H_{in} - 1 ) - 2 * padding + kernel\_size Hout​=(Hin​−1)−2∗padding+kernel_size,輸出尺寸剛好變成輸入的兩倍。最後一層采用kernel_size為5,stride為3,padding為1,是為了将32 * 32上采樣到96 * 96,這是本例中圖檔的尺寸,與論文中的64 * 64的尺寸不一樣。最後一層采用Tanh将輸出圖檔的像素歸一化至-1~1,如果希望歸一化至0~1則需要使用Sigmoid。

接着我們來看判别器的網絡結構。

class NetD(nn.Module):
    """
    判别器定義
    """

    def __init__(self, opt):
        super(NetD, self).__init__()
        ndf = opt.ndf
        self.main = nn.Sequential(
            # 輸入 3 x 96 x 96
            nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf) x 32 x 32

            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf*2) x 16 x 16

            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf*4) x 8 x 8

            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf*8) x 4 x 4

            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()  # 輸出一個數(機率)
        )

    def forward(self, input):
        return self.main(input).view(-1)

           

可以看出判别器和生成器的網絡結構幾乎是對稱的,從卷積核大小到padding、stride等設定,幾乎一模一樣。例如生成器的最後一個卷積層的尺度是(5,3,1),判别器的第一個卷積層的尺度也是(5,3,1)。另外,這裡需要注意的是生成器的激活函數用的是ReLU,而判别器使用的是LeakyReLU,二者并無本質差別,這裡的選擇更多是經驗總結。每一個樣本經過判别器後,輸出一個0~1的數,表示這個樣本是真圖檔的機率。

在開始寫訓練函數前,先來看看模型的配置參數。

class Config(object):
    data_path = 'data/'  # 資料集存放路徑
    num_workers = 4  # 多程序加載資料所用的程序數
    image_size = 96  # 圖檔尺寸
    batch_size = 256
    max_epoch = 200
    lr1 = 2e-4  # 生成器的學習率
    lr2 = 2e-4  # 判别器的學習率
    beta1 = 0.5  # Adam優化器的beta1參數
    gpu = True  # 是否使用GPU
    nz = 100  # 噪聲次元
    ngf = 64  # 生成器feature map數
    ndf = 64  # 判别器feature map數

    save_path = 'imgs/'  # 生成圖檔儲存路徑

    vis = True  # 是否使用visdom可視化
    env = 'GAN'  # visdom的env
    plot_every = 20  # 每間隔20 batch,visdom畫圖一次

    debug_file = '/tmp/debuggan'  # 存在該檔案則進入debug模式
    d_every = 1  # 每1個batch訓練一次判别器
    g_every = 5  # 每5個batch訓練一次生成器
    save_every = 10  # 沒10個epoch儲存一次模型
    netd_path = None  # 'checkpoints/netd_.pth' #預訓練模型
    netg_path = None  # 'checkpoints/netg_211.pth'

    # 隻測試不訓練
    gen_img = 'result.png'
    # 從512張生成的圖檔中儲存最好的64張
    gen_num = 64
    gen_search_num = 512
    gen_mean = 0  # 噪聲的均值
    gen_std = 1  # 噪聲的方差


opt = Config()
           

這些隻是模型的預設參數,還可以利用fire等工具通過指令行傳入,覆寫預設值。另外,我們也可以直接使用opt.attr,還可以利用IDE/IPython提供的自動補全功能,十分友善。這裡的超參數設定大多是照搬DCGAN論文的預設值,作者經過大量的實驗,發現這些參數能夠更快地訓練出一個不錯的模型。

當我們下載下傳完資料之後,需要将所有圖檔放在一個檔案夾,然後将該檔案夾移動至data目錄下(其確定data下沒有其他的檔案夾)。這種處理方式是為了能夠直接使用torchvision自帶的ImageFolder讀取圖檔,而不必自己寫Dataset。資料讀取與加載的代碼如下:

# 資料
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True
                                         )

           

可見,利用ImageFolder配合DataLoader加載圖檔十分友善。

在進行訓練之前,我們還需要定義幾個變量:模型、優化器、噪聲等。

# 網絡
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)


    # 定義優化器和損失
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # 真圖檔label為1,假圖檔label為0
    # noises為生成網絡的輸入
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

           

在加載預訓練模型時,最好指定map_location。因為如果程式之前在GPU上運作,那麼模型就會被存成torch.cuda.Tensor,這樣加載時會預設将資料加載至顯存。如果運作該程式的計算機中沒有GPU,加載就會報錯,故通過指定map_location将Tensor預設加載入記憶體中,待有需要時再移至顯存中。

下面開始訓練網絡,訓練步驟如下。

(1)訓練判别器

  • 固定生成器
  • 對于真圖檔,判别器的輸出機率值盡可能接近1
  • 對于生成器生成的假圖檔,判别器盡可能輸出0

(2)訓練生成器

  • 固定判别器
  • 生成器生成圖檔,盡可能讓判别器輸出1

(3)傳回第一步,循環交替訓練

epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # 訓練判别器
                optimizer_d.zero_grad()
                ## 盡可能的把真圖檔判别為正确
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## 盡可能把假圖檔判别為錯誤
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # 根據噪聲生成假圖
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # 訓練生成器
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## 可視化
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch+1) % opt.save_every == 0:
            # 儲存模型、圖檔
            tv.utils.save_image(fix_fake_imgs.data[:64], '%s/%s.png' % (opt.save_path, epoch), normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
           

這裡需要注意以下幾點。

  • 訓練生成器時,無須調整判别器的參數;訓練判别器時,無須調整生成器的參數。
  • 在訓練判别器時,需要對生成器生成的圖檔用detach操作進行計算圖截斷,避免反向傳播将梯度傳到生成器中。因為在訓練判别器時我們不需要訓練生成器,也就不需要生成器的梯度。
  • 在訓練判别器時,需要反向傳播兩次,一次是希望把真圖檔判為1,一次是希望把假圖檔判為0。也可以将這兩者的資料放到一個batch中,進行一次前向傳播和一次反向傳播即可。但是人們發現,在一個batch中隻包含真圖檔或隻包含假圖檔的做法最好。
  • 對于假圖檔,在訓練判别器時,我們希望它輸出0;而在訓練生成器時,我們希望它輸出1.是以可以看到一對看似沖突的代碼 error_d_fake = criterion(output, fake_labels)和error_g = criterion(output, true_labels)。其實這也很好了解,判别器希望能夠把假圖檔判别為fake_label,而生成器則希望能把他判别為true_label,判别器和生成器互相對抗提升。

接下來就是一些可視化的代碼。每次可視化使用的噪聲都是固定的fix_noises,因為這樣便于我們比較對于相同的輸入,生成器生成的圖檔是如何一步步提升的。另外,由于我們對輸入的圖檔進行了歸一化處理(-1~1),在可視化時則需要将它還原成原來的scale(0~1)。

fix_fake_imgs = netg(fix_noises)
 vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
           

除此之外,還提供了一個函數,能夠加載預訓練好的模型,并利用噪聲随機生成圖檔。

@t.no_grad()
def generate(**kwargs):
    """
    随機生成動漫頭像,并根據netd的分數選擇較好的
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    
    device=t.device('cuda') if opt.gpu else t.device('cpu')

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)


    # 生成圖檔,并計算圖檔在判别器的分數
    fake_img = netg(noises)
    scores = netd(fake_img).detach()

    # 挑選最好的某幾張
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])
    # 儲存圖檔
    tv.utils.save_image(t.stack(result), opt.gen_img, normalize=True, range=(-1, 1))

           

完整的代碼請參考本書的附帶樣例代碼chapter/AnimeGAN。參照README.MD中的指南配置環境,并準備好資料,而後用如下指令即可開始訓練:

python main.py train --gpu=True    # 使用GPU
                     --vis=True    # 使用visdom
                     --batch-size=256    # batch size
                     --max-epoch=200    # 訓練200個epoch
           

如果使用visdom的話,此時打開http://localhost:8097就能看到生成的圖像。

訓練完成後,我們可以利用生成網絡随機生成動漫圖像,輸入指令如下:

python main.py generate  --gen-img='result1.5w.png'
                         --gen-search-num=15000
           

7.3 實驗結果分析

實驗結果如下圖所示,分别是訓練1個、10個、20個、30個、40個、200個epoch之後神經網絡生成的動漫頭像(生成的圖像都在imgs檔案夾下)。需要注意的是,每次生成器輸入的噪聲都是一樣的,是以我們可以對比在相同的輸入下,生成圖檔的品質是如何慢慢改善的。

深度學習架構PyTorch入門與實踐:第七章 AI插畫師:生成對抗網絡

剛開始訓練的圖像比較模糊(1個epoch),但是可以看出圖像已經有面部輪廓。

繼續訓練10個epoch之後,生成的圖多了很多細節資訊,包括頭發、顔色等,但是總體還是模糊。

訓練20個epoch之後,細節繼續完善,包括頭發的紋理、眼睛的細節等,但還是有不少塗抹的痕迹。

訓練40個epoch時,已經能看出明顯的面部輪廓和細節,但還是有塗抹現象,并且有些細節不夠合理,例如眼睛一大一小,面部輪廓扭曲嚴重。

當訓練到200個epoch會後,圖檔的細節已經十分完善,線條更加流暢,輪廓更清晰,雖然還有一些不合理之處,但是已經有不少圖檔能夠以假亂真了。

類似的生成動漫頭像的項目還有《用DRGAN生成高清的動漫頭像》,效果如下圖所示。但遺憾的是,由于論文中使用的資料涉及版權問題,未能公開。這篇論文主要改進包括使用了更高品質的圖檔和更深、更複雜的模型。

深度學習架構PyTorch入門與實踐:第七章 AI插畫師:生成對抗網絡

本章講解的樣例程式還可以應用到不同的生成圖檔場景中,隻要将訓練圖檔改成其他類型的圖檔即可,例如LSUN房客圖檔集、MNIST手寫資料集或CIFAR10資料集等。事實上,上述模型還有很大的改進空間。在這裡,我們使用的全卷積網絡隻有四層,模型比較淺,而在ResNet的論文發表之後,也有不少研究者嘗試在GAN的網絡結構中引入Residual Block結構,并取得了不錯的視覺效果。感興趣的讀者可以嘗試将示例代碼中的單層卷積改為Residual Block,相信可以取得不錯的效果。

今年來,GAN的一個重大突破在于理論研究。論文《Towards Principled Methods for Training Generative Adversarial Networks》從理論的角度分析了GAN為何難以訓練,作者随後在另一篇論文《Wasserstein GAN》中針對性地提出了一個更好的解決方案。但是這篇論文在部分技術細節上的實作過于随意,是以随後又有人有針對性地提出了《Improved Training of Wasserstein GANs》,更好地訓練WGAN。後面兩篇論文分别用PyTorch和TensorFlow實作,代碼可以在GitHub上搜尋到。筆者當初也嘗試用100行左右的代碼實作了Wasserstein GAN,該興趣的讀者可以去了解。

随着GAN研究的逐漸成熟,人們也嘗試把GAN用于工業實際問題之中,而在衆多相關論文中,最令人深刻的就是《Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks》,論文中提出了一種新的GAN結構稱為CycleGAN。CycleGAN利用GAN實作風格遷移、黑白圖像彩色化,以及馬和斑馬互相轉化等,效果十分出衆。論文的作者用PyTorch實作了所有的代碼,并開源在GitHub上,感興趣的讀者可以自行查閱。

本章主要介紹GAN的基本原理,并帶領讀者利用GAN生成動漫頭像。GAN有許多變種,GitHub上有許多利用PyTorch實作的各種GAN,感興趣的讀者可以自行查閱。

繼續閱讀