天天看點

不到 200 行代碼,教你如何用 Keras 搭建生成對抗網絡(GAN)

不到 200 行代碼,教你如何用 Keras 搭建生成對抗網絡(GAN)

生成對抗網絡(generative adversarial networks,gan)最早由 ian goodfellow 在 2014 年提出,是目前深度學習領域最具潛力的研究成果之一。它的核心思想是:同時訓練兩個互相協作、同時又互相競争的深度神經網絡(一個稱為生成器 generator,另一個稱為判别器 discriminator)來處理無監督學習的相關問題。在訓練過程中,兩個網絡最終都要學習如何處理任務。

通常,我們會用下面這個例子來說明 gan 的原理:将警察視為判别器,制造假币的犯罪分子視為生成器。一開始,犯罪分子會首先向警察展示一張假币。警察識别出該假币,并向犯罪分子回報哪些地方是假的。接着,根據警察的回報,犯罪分子改進工藝,制作一張更逼真的假币給警方檢查。這時警方再回報,犯罪分子再改進工藝。不斷重複這一過程,直到警察識别不出真假,那麼模型就訓練成功了。

雖然 gan 的核心思想看起來非常簡單,但要搭建一個真正可用的 gan 網絡卻并不容易。因為畢竟在 gan 中有兩個互相耦合的深度神經網絡,同時對這兩個網絡進行梯度的反向傳播,也就比一般場景困難兩倍。

為此,本文将以深度卷積生成對抗網絡(deep convolutional gan,dcgan)為例,介紹如何基于 keras 2.0 架構,以 tensorflow 為後端,在 200 行代碼内搭建一個真實可用的 gan 模型,并以該模型為基礎自動生成 mnist 手寫體數字。

判别器的作用是判斷一個模型生成的圖像和真實圖像比,有多逼真。它的基本結構就是如下圖所示的卷積神經網絡(convolutional neural network,cnn)。對于 mnist 資料集來說,模型輸入是一個 28x28 像素的單通道圖像。sigmoid 函數的輸出值在 0-1 之間,表示圖像真實度的機率,其中 0 表示肯定是假的,1 表示肯定是真的。與典型的 cnn 結構相比,這裡去掉了層之間的 max-pooling,而是采用了步進卷積來進行下采樣。這裡每個 cnn 層都以 leakyrelu 為激活函數。而且為了防止過拟合和記憶效應,層之間的 dropout 值均被設定在 0.4-0.7 之間。具體在 keras 中的實作代碼如下。

不到 200 行代碼,教你如何用 Keras 搭建生成對抗網絡(GAN)

生成器的作用是合成假的圖像,其基本機構如下圖所示。圖中,我們使用了卷積的倒數,即轉置卷積(transposed convolution),從 100 維的噪聲(滿足 -1 至 1 之間的均勻分布)中生成了假圖像。如在 dcgan 模型中提到的那樣,去掉微步進卷積,這裡我們采用了模型前三層之間的上采樣來合成更逼真的手寫圖像。在層與層之間,我們采用了批量歸一化的方法來平穩化訓練過程。以 relu 函數為每一層結構之後的激活函數。最後一層 sigmoid 函數輸出最後的假圖像。第一層設定了 0.3-0.5 之間的 dropout 值來防止過拟合。具體代碼如下。

不到 200 行代碼,教你如何用 Keras 搭建生成對抗網絡(GAN)

下面我們生成真正的 gan 模型。如上所述,這裡我們需要搭建兩個模型:一個是判别器模型,代表警察;另一個是對抗模型,代表制造假币的犯罪分子。

判别器模型

下面代碼展示了如何在 keras 架構下生成判别器模型。上文定義的判别器是為模型訓練定義的損失函數。這裡由于判别器的輸出為 sigmoid 函數,是以采用了二進制交叉熵為損失函數。在這種情況下,以 rmsprop 作為優化算法可以生成比 adam 更逼真的假圖像。這裡我們将學習率設定在 0.0008,同時還設定了權值衰減和clipvalue等參數來穩定後期的訓練過程。如果你需要調節學習率,那麼也必須同步調節其他相關參數。

對抗模型

如圖所示,對抗模型的基本結構是判别器和生成器的疊加。生成器試圖騙過判别器,同時從其回報中提升自己。如下代碼中示範了如何基于 keras 架構實作這一部分功能。其中,除了學習速率的降低和相對權值衰減之外,訓練參數與判别器模型中的訓練參數完全相同。

不到 200 行代碼,教你如何用 Keras 搭建生成對抗網絡(GAN)

訓練

搭好模型之後,訓練是最難實作的部分。這裡我們首先用真實圖像和假圖像對判别器模型單獨進行訓練,以判斷其正确性。接着,對判别器模型和對抗模型輪流展開訓練。如下圖展示了判别器模型訓練的基本流程。在 keras 架構下的實作代碼如下所示。

不到 200 行代碼,教你如何用 Keras 搭建生成對抗網絡(GAN)

訓練過程中需要非常耐心,這裡列出一些常見問題和解決方案:

問題1:最終生成的圖像噪點太多。

解決:嘗試在判别器和生成器模型上引入 dropout,一般更小的 dropout 值(0.3-0.6)可以産生更逼真的圖像。

問題2:判别器的損失函數迅速收斂為零,導緻發生器無法訓練。

解決:不要對判别器進行預訓練。而是調整學習率,使判别器的學習率大于對抗模型的學習率。也可以嘗試對生成器換一個不同的訓練噪聲樣本。

問題3:生成器輸出的圖像仍然看起來像噪聲。

解決:檢查激活函數、批量歸一化和 dropout 的應用流程是否正确。

問題4:如何确定正确的模型/訓練參數。

解決:嘗試從一些已經發表的論文或代碼中找到參考,調試時每次隻調整一個參數。在進行 2000 步以上的訓練時,注意觀察在 500 或 1000 步左右參數值調整的效果。

下圖展示了在訓練過程中,整個模型的輸出變化情況。可以看到,gan 在自己學習如何生成手寫體數字。

不到 200 行代碼,教你如何用 Keras 搭建生成對抗網絡(GAN)

完整代碼位址:

本文作者:恒亮

繼續閱讀