天天看點

GAN學習指南:從原理入門到制作生成Demo,總共分幾步?

生成式對抗網絡(gan)是近年來大熱的深度學習模型。最近正好有空看了這方面的一些論文,跑了一個gan的代碼,于是寫了這篇文章來介紹一下gan。

本文主要分為三個部分:

介紹原始的gan的原理

同樣非常重要的dcgan的原理

如何在tensorflow跑dcgan的代碼,生成如題圖所示的動漫頭像,附送資料集哦 :-)

一、gan原理介紹

gan的基本原理其實非常簡單,這裡以生成圖檔為例進行說明。假設我們有兩個網絡,g(generator)和d(discriminator)。正如它的名字所暗示的那樣,它們的功能分别是:

g是一個生成圖檔的網絡,它接收一個随機的噪聲z,通過這個噪聲生成圖檔,記做g(z)。

d是一個判别網絡,判别一張圖檔是不是“真實的”。它的輸入參數是x,x代表一張圖檔,輸出d(x)代表x為真實圖檔的機率,如果為1,就代表100%是真實的圖檔,而輸出為0,就代表不可能是真實的圖檔。

在訓練過程中,生成網絡g的目标就是盡量生成真實的圖檔去欺騙判别網絡d。而d的目标就是盡量把g生成的圖檔和真實的圖檔分别開來。這樣,g和d構成了一個動态的“博弈過程”。

最後博弈的結果是什麼?在最理想的狀态下,g可以生成足以“以假亂真”的圖檔g(z)。對于d來說,它難以判定g生成的圖檔究竟是不是真實的,是以d(g(z)) = 0.5。

這樣我們的目的就達成了:我們得到了一個生成式的模型g,它可以用來生成圖檔。

以上隻是大緻說了一下gan的核心原理,如何用數學語言描述呢?這裡直接摘錄論文裡的公式:

GAN學習指南:從原理入門到制作生成Demo,總共分幾步?

簡單分析一下這個公式:

整個式子由兩項構成。x表示真實圖檔,z表示輸入g網絡的噪聲,而g(z)表示g網絡生成的圖檔。

d(x)表示d網絡判斷真實圖檔是否真實的機率(因為x就是真實的,是以對于d來說,這個值越接近1越好)。而d(g(z))是d網絡判斷g生成的圖檔的是否真實的機率。

g的目的:上面提到過,d(g(z))是d網絡判斷g生成的圖檔是否真實的機率,g應該希望自己生成的圖檔“越接近真實越好”。也就是說,g希望d(g(z))盡可能得大,這時v(d, g)會變小。是以我們看到式子的最前面的記号是min_g。

d的目的:d的能力越強,d(x)應該越大,d(g(x))應該越小。這時v(d,g)會變大。是以式子對于d來說是求最大(max_d)

下面這幅圖檔很好地描述了這個過程:

GAN學習指南:從原理入門到制作生成Demo,總共分幾步?

那麼如何用随機梯度下降法訓練d和g?論文中也給出了算法:

GAN學習指南:從原理入門到制作生成Demo,總共分幾步?

這裡紅框圈出的部分是我們要額外注意的。第一步我們訓練d,d是希望v(g, d)越大越好,是以是加上梯度(ascending)。第二步訓練g時,v(g, d)越小越好,是以是減去梯度(descending)。整個訓練過程交替進行。

二、dcgan原理介紹

dcgan的原理和gan是一樣的,這裡就不在贅述。它隻是把上述的g和d換成了兩個卷積神經網絡(cnn)。但不是直接換就可以了,dcgan對卷積神經網絡的結構做了一些改變,以提高樣本的品質和收斂的速度,這些改變有:

取消所有pooling層。g網絡中使用轉置卷積(transposed convolutional layer)進行上采樣,d網絡中用加入stride的卷積代替pooling。

在d和g中均使用batch normalization

去掉fc層,使網絡變為全卷積網絡

g網絡中使用relu作為激活函數,最後一層使用tanh

d網絡中使用leakyrelu作為激活函數

dcgan中的g網絡示意:

GAN學習指南:從原理入門到制作生成Demo,總共分幾步?

三、dcgan in tensorflow

好了,上面說了一通原理,下面說點有意思的實踐部分的内容。

GAN學習指南:從原理入門到制作生成Demo,總共分幾步?

這是個很有趣的實踐内容。可惜原文是用chainer做的,這個架構使用的人不多。下面我們就在tensorflow中複現這個結果。

1. 原始資料集的搜集

爬蟲代碼如下:

GAN學習指南:從原理入門到制作生成Demo,總共分幾步?

這個爬蟲大概跑了一天,爬下來12萬張圖檔,大概是這樣的:

GAN學習指南:從原理入門到制作生成Demo,總共分幾步?

可以看到這裡面的圖檔大多數比較雜亂,還不能直接作為資料訓練,我們需要用合适的工具,截取人物的頭像進行訓練。

2. 頭像截取

簡單包裝下代碼:

GAN學習指南:從原理入門到制作生成Demo,總共分幾步?

截取頭像後的人物資料:

GAN學習指南:從原理入門到制作生成Demo,總共分幾步?

這樣就可以用來訓練了!

3. 訓練

不過原始代碼中隻提供了有限的幾個資料庫,如何訓練自己的資料?在model.py中我們找到讀資料的幾行代碼:

if config.dataset == 'mnist':            data_x, data_y = self.load_mnist()        else:            data = glob(os.path.join("./data", config.dataset, "*.jpg"))

這樣讀資料的邏輯就很清楚了,我們在data檔案夾中再建立一個anime檔案夾,把圖檔直接放到這個檔案夾裡,運作時指定--dataset anime即可。

運作指令(參數含義:指定生成的圖檔的尺寸為48x48,我們圖檔的大小是96x96,跑300個epoch):

python main.py --image_size 96 --output_size 48 --dataset anime --is_crop true --is_train true --epoch 300

4. 結果

第1個epoch跑完(隻有一點點輪廓):

GAN學習指南:從原理入門到制作生成Demo,總共分幾步?

第5個epoch之後的結果:

GAN學習指南:從原理入門到制作生成Demo,總共分幾步?

第10個epoch:

GAN學習指南:從原理入門到制作生成Demo,總共分幾步?

200個epoch,仔細看有些圖檔确實是足以以假亂真的:

GAN學習指南:從原理入門到制作生成Demo,總共分幾步?

題圖是我從第300個epoch生成的。

四、總結和後續

簡單介紹了一下gan和dcgan的原理。以及如何使用tensorflow做一個簡單的生成圖檔的demo。

一些後續閱讀:

本文作者:何之源

繼續閱讀