天天看點

GAN網絡通俗解釋(圖畫版)

在本教程中,你将了解什麼是生成敵對網絡(GAN),并且在整個過程中不涉及負責的數學細節。之後,你還将學習如何編寫一個可以建立數字的簡單GAN!

GAN網絡通俗解釋(圖畫版)

什麼是GAN(插畫版介紹)

了解GAN的最簡單方法是通過一個簡單的比喻:

假設有一家商店它們從顧客那裡購買某些種類的葡萄酒,用于以後再銷售。

GAN網絡通俗解釋(圖畫版)

然而,有些惡意的顧客為了獲得金錢而出售假酒。在這種情況下,店主必須能夠區分假酒和正品葡萄酒。

GAN網絡通俗解釋(圖畫版)

你可以想象,最初,僞造者在嘗試出售假酒時可能會犯很多錯誤,并且店主很容易認定該酒不是真的。由于這些失敗,僞造者會繼續嘗試使用不同的技術來模拟真正的葡萄酒,最終才有可能成功。現在,僞造者知道某些技術已經超過了店主的認識假酒的能力,他可以開始進一步生産基于這些技術的假酒。

同時,店主可能會從其他店主或葡萄酒專家那裡得到一些回報,說明他擁有的一些葡萄酒不是原裝的。這意味着店主必須改善他是如何确定葡萄酒是僞造的還是真實的。僞造者的目标是制造與真實葡萄酒無法區分的葡萄酒,而店主的目标是準确地分辨葡萄酒是否真實。

這種來回的競争博弈就是GAN網絡背後的主要思想。

生成敵對網絡的組成部分

使用上面的例子,我們可以想出一個GAN的體系結構。

GAN網絡通俗解釋(圖畫版)

GAN網絡中有兩個主要元件:生成器和鑒别器。這個例子中的店主被稱為鑒别器網絡,并且通常是

卷積神經網絡

(因為GAN主要用于圖像任務),其主要功能是判斷圖像是真實的機率。

僞造者被稱為生成網絡,并且通常也是卷積神經網絡(具有

解卷積層

)。該網絡需要一些噪聲矢量并輸出圖像。在訓練生成網絡時,它會學習圖像的哪些區域進行改進/更改,以便鑒别器将難以将其生成的圖像與真實圖像區分開來。

生成網絡不斷生成更接近真實圖像的圖像,而辨識網絡試圖确定真實圖像和假圖像之間的差異。最終的目标是建立一個可生成與真實圖像無法區分的圖像的生成網絡。

一個簡單的Keras生成對抗網絡

現在你已經了解了GAN是什麼以及它們的主要組成部分,現在我們可以開始編寫一個非常簡單的代碼。本教程将使用

Keras

,如果你不熟悉此Python庫,則應在繼續之前閱讀翻譯小組其他文章。本教程是基于

這裡

開發的非常酷且易于了解的GAN。

你需要做的第一件事是通過以下方式安裝以下軟體包

pip

- keras
- matplotlib
- tensorflow
- tqdm           

你将

matplotlib

用于繪制

tensorflow

——

Keras後端庫,并用

tqdm

為每個時期(疊代)顯示一個奇特的進度條。

下一步是建立一個Python腳本。在這個腳本中,你首先需要導入你将要使用的所有子產品和函數,在使用它們時将給出每個解釋。

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Dense, Dropout
from keras.layers.advanced_activations import LeakyReLU
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers           

你現在想要設定一些變量值:

# Let Keras know that we are using tensorflow as our backend engine
os.environ["KERAS_BACKEND"] = "tensorflow"
# To make sure that we can reproduce the experiment and get the same results
np.random.seed(10)
# The dimension of our random noise vector.
random_dim = 100           

在開始建構鑒别器和生成器之前,你應該首先收集并預處理資料。你将使用現在最流行的MNIST資料集,該資料集具有一組從0到9範圍内的單個數字的圖像。

GAN網絡通俗解釋(圖畫版)
def load_minst_data():
    # load the data
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # normalize our inputs to be in the range[-1, 1] 
    x_train = (x_train.astype(np.float32) - 127.5)/127.5
    # convert x_train with a shape of (60000, 28, 28) to (60000, 784) so we have
    # 784 columns per row
    x_train = x_train.reshape(60000, 784)
    return (x_train, y_train, x_test, y_test)           

請注意,

mnist.load_data()

這個函數

是Keras的一部分,它允許你輕松将MNIST資料集導入你的工作區。

現在,你可以建立你的生成器和鑒别器網絡。你可以為這兩個網絡使用

Adam優化器

。對于生成器和鑒别器,你将建立一個帶有三個隐藏層的神經網絡,激活函數為

Leaky Relu

。你還應該為鑒别器添加

Drop-out圖層

,以提高其對未見圖像的魯棒性。

def get_optimizer():
    return Adam(lr=0.0002, beta_1=0.5) 
def get_generator(optimizer):
    generator = Sequential()
    generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(512))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(1024))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(784, activation='tanh'))
    generator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return generator
def get_discriminator(optimizer):
    discriminator = Sequential()
    discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(512))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(256))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(1, activation='sigmoid'))
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return discriminator           

終于到了将生成器和鑒别器放在一起的時候了!

def get_gan_network(discriminator, random_dim, generator, optimizer):
    # We initially set trainable to False since we only want to train either the 
    # generator or discriminator at a time
    discriminator.trainable = False
    # gan input (noise) will be 100-dimensional vectors
    gan_input = Input(shape=(random_dim,))
    # the output of the generator (an image)
    x = generator(gan_input)
    # get the output of the discriminator (probability if the image is real or not)
    gan_output = discriminator(x)
    gan = Model(inputs=gan_input, outputs=gan_output)
    gan.compile(loss='binary_crossentropy', optimizer=optimizer)
    return gan           

為了保持整個過程的完整性,你可以建立一個功能,每20個紀元儲存你生成的圖像。由于這不是本教程的核心,是以你不需要完全了解該功能。

def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, random_dim])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples, 28, 28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('gan_generated_image_epoch_%d.png' % epoch)           

你現在已經編碼了大部分網絡,剩下的就是訓練這個網絡,并看看你建立的圖像。

def train(epochs=1, batch_size=128):
    # Get the training and testing data
    x_train, y_train, x_test, y_test = load_minst_data()
    # Split the training data into batches of size 128
    batch_count = x_train.shape[0] / batch_size
    # Build our GAN netowrk
    adam = get_optimizer()
    generator = get_generator(adam)
    discriminator = get_discriminator(adam)
    gan = get_gan_network(discriminator, random_dim, generator, adam)

    for e in xrange(1, epochs+1):
        print '-'*15, 'Epoch %d' % e, '-'*15
        for _ in tqdm(xrange(batch_count)):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]

            # Generate fake MNIST images
            generated_images = generator.predict(noise)
            X = np.concatenate([image_batch, generated_images])
            # Labels for generated and real data
            y_dis = np.zeros(2*batch_size)
            # One-sided label smoothing
            y_dis[:batch_size] = 0.9
            # Train discriminator
            discriminator.trainable = True
            discriminator.train_on_batch(X, y_dis)
            # Train generator
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            y_gen = np.ones(batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise, y_gen)

        if e == 1 or e % 20 == 0:
            plot_generated_images(e, generator)

if __name__ == '__main__':
    train(400, 128)           

訓練400個紀元後,你可以檢視生成的圖像。檢視第一個紀元後産生的圖像,可以看到它沒有任何真實的結構,在40個紀元後檢視圖像,數字開始成形,最後,400個紀元後産生的圖像顯示出清晰的數字,盡管是一對夫婦仍然無法辨認。

GAN網絡通俗解釋(圖畫版)

1紀元(左)後的結果40個紀元後(中)的結果400個時代後的結果(右)

此代碼在CPU上每個紀元大約需要2分鐘,這是選擇此代碼的主要原因。你可以嘗試使用更多的紀元,并通過向生成器和鑒别器添加更多(和不同的)圖層。但是,當使用更複雜和更深的體系結構時,如果僅使用CPU,則運作時也會增加。

結論

恭喜,你已經完成了本教程的最後部分,你已經以直覺的方式學習生成敵對網絡(GAN)的基礎知識!

數十款阿裡雲産品限時折扣中,趕緊點選領劵開始雲上實踐吧!

本文由@

阿裡雲雲栖社群

組織翻譯。

文章原标題《 demystifying-generative-adversarial-networks 》, 譯者:虎說八道,審校:袁虎。 文章為簡譯,更為詳細的内容,請檢視 原文