天天看點

機器學習分享——手把手帶你寫一個GAN

GAN

今天讓我們從這幾方面來探索:

  • GAN

    能用來做什麼
  • GAN

    的原理
  • GAN

    的代碼實作

用途

GAN

自2014年誕生以來, 就一直備受關注, 著名的應用也随即産出, 比如比較著名的GAN的應用有Pix2Pix,CycleGAN等, 大家也将它用于各個地方。

  1. 缺失/模糊像素的補充
  2. 圖檔修複
  3. ……

我覺得還有一個比較重要的用途, 很多人都會缺少資料集, 那麼就通過

GAN

去生成資料集了, 通過調節部分參數來進行資料集的産生的相似度。

原理

GAN的基本原理其實非常簡單,這裡以生成圖檔為例進行說明。假設我們有兩個網絡,

G(Generator)

D(Discriminator)

。正如它的名字所暗示的那樣, 它們的功能分别是:

G

是一個生成圖檔的網絡, 它接收一個随機的噪聲(随機生成的圖檔)

z

, 通過這個噪聲生成圖檔,記做

G(z)

.

D

是一個判别網絡, 判别一張圖檔是不是“真實的”。它的輸入參數是

x

,

x

代表一張圖檔,輸出

D(x)

代表

x

為真實圖檔的機率,如果為

>0.5

,就代表是真實(相似)的圖檔,反之,就代表不是真實的圖檔。

機器學習分享——手把手帶你寫一個GAN

我們通過一個假産品宣傳的例子來了解:

首先, 我們來定義一下角色.

  1. 進行宣傳的

    '專家'(生成網絡)

  2. 正在聽講的

    '我們'(判别網絡)

'專家'

的手裡面拿着一堆高仿的産品, 正在進行宣講, 我們是熟知真品的相關資訊的, 通過去對比兩個産品之間的差距, 來判斷是赝品的可能性.

這時, 我們就可以引出來一個概念, 如果

'專家'

團隊比較厲害, 完美的仿造了我們的判斷依據, 比如說産出方, 發明日期, 說明文等等, 那麼我們就會覺得他是真的, 那麼他就是一個好的

生成網絡

, 反之, 我們會判斷他是赝品.

我們(判别網絡)

出發, 我們的判斷條件越苛刻, 赝品和真品之間的差距會越來越小, 這樣的最後的産出就是真假難分, 完全被模仿了.

相關資源

深層的原理推薦大家可以去閱讀Generative Adversarial Networks這篇論文

損失函數等相關細節我們在實作裡介紹。

實作

接下來我們就以

mnist

來實作

GAN

吧.

  1. 首先, 我們先下載下傳資料集.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./Dataset/datasets/MNIST_data', one_hot=False)
           

我們通過

tensorflow

去下載下傳

mnist

的資料集, 然後加載到記憶體,

one-hot

參數決定我們的

label

是否要經過編碼(mnist資料集是有10個類别), 但是我們判别網絡是對比真實的和生成的之間的差別以及相似的可能性, 是以不需要執行

one-hot

編碼了.

這裡讀取出來的圖檔已經歸一化到[0, 1]之間了.

  1. 俗話說, 知己知彼, 百戰百勝, 那我們拿到資料集, 就先來看看它長什麼樣.
def show_images(images):
    images = np.reshape(images, [images.shape[0], -1])  # images reshape to (batch_size, D)
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg,sqrtimg]))
           
機器學習分享——手把手帶你寫一個GAN

這裡有一個小問題, 如果是在

Notebook

中執行, 記得加上這句話, 否則需要執行兩次才會繪制.

%matplotlib inline
           
  1. 資料看過了, 我們該對它進行一定的處理了, 這裡我們隻是将資料縮放到[-1, 1]之間.
def preprocess_img(x):
    return 2 * x - 1.0

def deprocess_img(x):
    return (x + 1.0) / 2.0
           
  1. 資料處理完了, 接下來我們要開始搭模組化型了, 這一部分我們有兩個模型, 一個生成網絡, 一個判别網絡.

生成網絡

def generator(z):

    with tf.variable_scope("generator"):

        fc1 = tf.layers.dense(inputs=z, units=1024, activation=tf.nn.relu)
        bn1 = tf.layers.batch_normalization(inputs=fc1, training=True)
        fc2 = tf.layers.dense(inputs=bn1, units=7*7*128, activation=tf.nn.relu)
        bn2 = tf.layers.batch_normalization(inputs=fc2, training=True)
        reshaped = tf.reshape(bn2, shape=[-1, 7, 7, 128])
        conv_transpose1 = tf.layers.conv2d_transpose(inputs=reshaped, filters=64, kernel_size=4, strides=2, activation=tf.nn.relu,
                                                    padding='same')
        bn3 = tf.layers.batch_normalization(inputs=conv_transpose1, training=True)
        conv_transpose2 = tf.layers.conv2d_transpose(inputs=bn3, filters=1, kernel_size=4, strides=2, activation=tf.nn.tanh,
                                        padding='same')

        img = tf.reshape(conv_transpose2, shape=[-1, 784])
        return img
           

判别網絡

def discriminator(x):

    with tf.variable_scope("discriminator"):

        unflatten = tf.reshape(x, shape=[-1, 28, 28, 1])
        conv1 = tf.layers.conv2d(inputs=unflatten, kernel_size=5, strides=1, filters=32 ,activation=leaky_relu)
        maxpool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=2, strides=2)
        conv2 = tf.layers.conv2d(inputs=maxpool1, kernel_size=5, strides=1, filters=64,activation=leaky_relu)
        maxpool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=2, strides=2)
        flatten = tf.reshape(maxpool2, shape=[-1, 1024])
        fc1 = tf.layers.dense(inputs=flatten, units=1024, activation=leaky_relu)
        logits = tf.layers.dense(inputs=fc1, units=1)

        return logits
           

激活函數我們使用了

leaky_relu

, 他的代碼實作是

def leaky_relu(x, alpha=0.01):
    activation = tf.maximum(x,alpha*x)
    return activation
           

它和

relu

的差別就是, 小于0的值也會給與一點小的權重進行保留.

  1. 建立損失函數
def gan_loss(logits_real, logits_fake):

    # Target label vector for generator loss and used in discriminator loss.
    true_labels = tf.ones_like(logits_fake)

    # DISCRIMINATOR loss has 2 parts: how well it classifies real images and how well it
    # classifies fake images.
    real_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=true_labels)
    fake_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=1-true_labels)

    # Combine and average losses over the batch
    D_loss = real_image_loss + fake_image_loss
    D_loss = tf.reduce_mean(D_loss)

    # GENERATOR is trying to make the discriminator output 1 for all its images.
    # So we use our target label vector of ones for computing generator loss.
    G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=true_labels)

    # Average generator loss over the batch.
    G_loss = tf.reduce_mean(G_loss)

    return D_loss, G_loss
           

損失我們分為兩部分, 一部分是生成網絡的, 一部分是判别網絡的.

生成網絡的損失定義為, 生成圖像的類别與真實标簽(全是1)的交叉熵損失.

l o s s g e n e r a t e = − ∑ i n ( Y i l o g G i + ( 1 − Y i ) l o g G i ) loss_{generate} = -\sum_i^n(Y_ilogG_i + (1 - Y_i)logG_i) lossgenerate​=−i∑n​(Yi​logGi​+(1−Yi​)logGi​)

判别網絡的損失定義為, 我們将真實圖檔的标簽設定為1, 生成圖檔的标簽設定為0, 然後由真實圖檔的輸出以及生成圖檔的輸出的交叉熵損失和.

T: True

,

G: Generate

生成損失

l o s s G = − ∑ i n ( Y i l o g G i + ( 1 − Y i ) l o g G i ) loss_{G} = -\sum_i^n(Y_ilogG_i + (1 - Y_i)logG_i) lossG​=−i∑n​(Yi​logGi​+(1−Yi​)logGi​)

真實圖檔損失

l o s s T = − ∑ i n ( Y i l o g T i + ( 1 − Y i ) l o g T i ) loss_{T} = -\sum_i^n(Y_ilogT_i + (1 - Y_i)logT_i) lossT​=−i∑n​(Yi​logTi​+(1−Yi​)logTi​)

總損失

L o s s D = l o s s G + l o s s T Loss_{D} = loss_{G} + loss_{T} LossD​=lossG​+lossT​

  1. 訓練
def run_a_gan(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step,\
              show_every=250, print_every=50, batch_size=128, num_epoch=10):
    # compute the number of iterations we need
    max_iter = int(mnist.train.num_examples*num_epoch/batch_size)
    for it in range(max_iter):
        # every show often, show a sample result
        if it % show_every == 0:
            samples = sess.run(G_sample)
            fig = show_images(samples[:16])
            plt.show()
            print()
        # run a batch of data through the network
        minibatch,minbatch_y = mnist.train.next_batch(batch_size)
        _1, D_loss_curr = sess.run([D_train_step, D_loss], feed_dict={x: minibatch})
        _2, G_loss_curr = sess.run([G_train_step, G_loss])
        if it % show_every == 0:
            print(_1,_2)
        # print loss every so often.
        # We want to make sure D_loss doesn't go to 0
        if it % print_every == 0:
            print('Iter: {}, D: {:.4}, G:{:.4}'.format(it,D_loss_curr,G_loss_curr))
    print('Final images')
    samples = sess.run(G_sample)

    fig = show_images(samples[:16])
    plt.show()
           

這裡就是開始訓練了, 并展示訓練的結果.

  1. 檢視結果

剛開始的時候, 還沒學會怎麼模仿:

機器學習分享——手把手帶你寫一個GAN

經過學習改進:

機器學習分享——手把手帶你寫一個GAN
機器學習分享——手把手帶你寫一個GAN

項目位址

檢視源碼(請在PC端打開)

聲明

該文章參考了知乎。

——————————————————————————————————

Mo (網址:momodel.cn )是一個支援 Python 的人工智能模組化平台,能幫助你快速開發訓練并部署 AI 應用。期待你的加入。

繼續閱讀