天天看點

5分鐘入門GANS:原了解釋和keras代碼實作

本篇文章包含以下内容

  1. 介紹
  2. 曆史
  3. 直覺解釋
  4. 訓練過程
  5. GAN在MNIST資料集上的KERAS實作

介紹

生成式對抗網絡通常也稱為GANs,用于生成圖像而不需要很少或沒有輸入。GANs允許我們生成由神經網絡生成的圖像。在我們深入讨論這個理論之前,我想向您展示GANs建構您興奮感的能力。把馬變成斑馬(反之亦然)。

5分鐘入門GANS:原了解釋和keras代碼實作
5分鐘入門GANS:原了解釋和keras代碼實作

曆史

生成式對抗網絡(GANs)是由Ian Goodfellow (GANs的GAN Father)等人于2014年在其題為“生成式對抗網絡”的論文中提出的。它是一種可替代的自适應變分編碼器(VAEs)學習圖像的潛在空間,以生成合成圖像。它的目的是創造逼真的人工圖像,幾乎無法與真實的圖像區分。

GAN的直覺解釋

生成器和鑒别器網絡:

生成器網絡的目的是将随機圖像初始化并解碼成一個合成圖像。

鑒别器網絡的目的是擷取這個輸入,并預測這個圖像是來自真實的資料集還是合成的。

正如我們剛才看到的,這實際上就是GANs,兩個互相競争的對抗網絡。

GAN的訓練過程

GANS的訓練是出了名的困難。在CNN中,我們使用梯度下降來改變權重以減少損失。

然而,在GANs中,每一次重量的變化都會改變整個動态系統的平衡。

在GAN的網絡中,我們不是在尋求将損失最小化,而是在我們對立的兩個網絡之間找到一種平衡。

我們将過程總結如下

  1. 輸入随機生成的噪聲圖像到我們的生成器網絡中生成樣本圖像。
  2. 我們從真實資料中提取一些樣本圖像,并将其與一些生成的圖像混合在一起。
  3. 将這些混合圖像輸入到我們的鑒别器中,鑒别器将對這個混合集進行訓練并相應地更新它的權重。
  4. 然後我們制作更多的假圖像,并将它們輸入到鑒别器中,但是我們将它們标記為真實的。這樣做是為了訓練生成器。我們在這個階段當機了鑒别器的權值(鑒别器學習停止),并且我們使用來自鑒别器的回報來更新生成器的權值。這就是我們如何教我們的生成器(制作更好的合成圖像)和鑒别器更好地識别赝品的方法。

流程圖如下

5分鐘入門GANS:原了解釋和keras代碼實作

對于本文,我們将使用MNIST資料集生成手寫數字。GAN的架構是:

5分鐘入門GANS:原了解釋和keras代碼實作

使用KERAS實作GANS

首先,我們加載所有必要的庫

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Reshape, Dense, Dropout, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import backend as K
from keras import initializers

K.set_image_dim_ordering('th')

# Deterministic output.
# Tired of seeing the same results every time? Remove the line below.
np.random.seed(1000)

# The results are a little better when the dimensionality of the random vector is only 10.
# The dimensionality has been left at 100 for consistency with other GAN implementations.
randomDim = 100           

複制

現在我們加載資料集。這裡使用MNIST資料集,是以不需要單獨下載下傳和處理。

(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5)/127.5
X_train = X_train.reshape(60000, 784)           

複制

接下來,我們定義生成器和鑒别器的結構

# Optimizer
adam = Adam(lr=0.0002, beta_1=0.5)#generator
generator = Sequential()
generator.add(Dense(256, input_dim=randomDim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=adam)#discriminator
discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=adam)           

複制

現在我們把發生器和鑒别器結合起來同時訓練。

# Combined network
discriminator.trainable = False
ganInput = Input(shape=(randomDim,))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
gan.compile(loss='binary_crossentropy', optimizer=adam)

dLosses = []
gLosses = []           

複制

三個函數,每20個epoch繪制并儲存結果,并儲存模型。

# Plot the loss from each batch
def plotLoss(epoch):
    plt.figure(figsize=(10, 8))
    plt.plot(dLosses, label='Discriminitive loss')
    plt.plot(gLosses, label='Generative loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('images/gan_loss_epoch_%d.png' % epoch)

# Create a wall of generated MNIST images
def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, randomDim])
    generatedImages = generator.predict(noise)
    generatedImages = generatedImages.reshape(examples, 28, 28)

    plt.figure(figsize=figsize)
    for i in range(generatedImages.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)

# Save the generator and discriminator networks (and weights) for later use
def saveModels(epoch):
    generator.save('models/gan_generator_epoch_%d.h5' % epoch)
    discriminator.save('models/gan_discriminator_epoch_%d.h5' % epoch)           

複制

訓練函數

def train(epochs=1, batchSize=128):
    batchCount = X_train.shape[0] / batchSize
    print 'Epochs:', epochs
    print 'Batch size:', batchSize
    print 'Batches per epoch:', batchCount

    for e in xrange(1, epochs+1):
        print '-'*15, 'Epoch %d' % e, '-'*15
        for _ in tqdm(xrange(batchCount)):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]

            # Generate fake MNIST images
            generatedImages = generator.predict(noise)
            # print np.shape(imageBatch), np.shape(generatedImages)
            X = np.concatenate([imageBatch, generatedImages])

            # Labels for generated and real data
            yDis = np.zeros(2*batchSize)
            # One-sided label smoothing
            yDis[:batchSize] = 0.9

            # Train discriminator
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X, yDis)

            # Train generator
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            yGen = np.ones(batchSize)
            discriminator.trainable = False
            gloss = gan.train_on_batch(noise, yGen)

        # Store loss of most recent batch from this epoch
        dLosses.append(dloss)
        gLosses.append(gloss)

        if e == 1 or e % 20 == 0:
            plotGeneratedImages(e)
            saveModels(e)

    # Plot losses from every epoch
    plotLoss(e)           

複制

至此一個簡單的GAN已經完成了,完整的代碼在這裡找到

https://github.com/bhaveshgoyal27/mediumblogs/blob/master/Keras_MNIST_GAN.py

作者:Bhavesh Goyal

deephub翻譯組