天天看點

GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)

1 InfoGAN-帶有隐含資訊的GAN

       InfoGAN是一種把資訊論與GAN相融合的神經網絡,能夠使網絡具有資訊解讀功能。

       GAN的生成器在建構樣本時使用了任意的噪聲向量x’,并從低維的噪聲資料x’中還原出來高維的樣本資料。這說明資料x’中含有具有與樣本相同的特征。

       由于随意使用的噪聲都能還原出高維樣本資料,表明噪聲中的特征資料部分是與無用的資料部分高度地糾纏在一起的,即我們能夠知道噪聲中含有有用特征,但無法知道哪些是有用特征。

       InfoGAN是GAN模型的一種改進,是一種能夠學習樣本中的關鍵次元資訊的GAN,即對生成樣本的噪音進行了細化。先來看它的結構,相比對抗自編碼,InfoGAN的思路正好相反,InfoGAN是先固定标準高斯分布作為網絡輸入,再慢慢調整網絡輸出去比對複雜樣本分布。

GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)

                                                                                                         圖3.1 InfoGAN模型

       如圖3.1所示,InfoGAN生成器是從标準高斯分布中随機采樣來作為輸入,生成模拟樣本,解碼器是将生成器輸出的模拟樣本還原回生成器輸入的随機數中的一部分,判别器是将樣本作為輸入來區分真假樣本。

       InfoGAN的理論思想是将輸入的随機标準高斯分布當成噪音資料,并将噪音分為兩類,第一類是不可壓縮的噪音Z,第二類是可解釋性的資訊C。假設在一個樣本中,決定其本身的隻有少量重要的次元,那麼大多數的次元是可以忽略的。而這裡的解碼器可以更形象地叫成重構器,即通過重構一部分輸入的特征來确定與樣本互資訊的那些次元。最終被找到的次元可以代替原始樣本的特征(類似PCA算法中的主成份),實作降維、解耦的效果。

2 AC-GAN-帶有輔助分類資訊的GAN

       AC-GAN(Auxiliary Classifier GAN),即在判别器discriminator中再輸出相應的分類機率,然後增加輸出的分類與真實分類的損失計算,使生成的模拟資料與其所屬的class一一對應。一般來講,AC-GAN可以屬于InfoGAN的一部分,class資訊可以作為InfoGAN中的潛在資訊,隻不過這部分資訊可以使用半監督方式來學習。

GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)
GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)
GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)

3 代碼

       首先明确,GAN的代碼沒有目标檢測的複雜,以一個目标檢測程式demo的篇幅就涵蓋了GAN的資料輸入、訓練、定義網絡結構和參數、loss函數和優化器以及可視化部分。

       還可以學習到的是,GAN基本除開兩個大的網絡架構G和D以外,就是加各種限制(分類資訊、隐含資訊等)用以生成想要的資料。

       下面是代碼實作學習MINST資料特征,生成以假亂真的MNIST模拟樣本,并發現内部潛在的特征資訊。

GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)

代碼總綱:

  1. 加載資料集;
  2. 定義G和D;
  3. 定義網絡模型的參數、輸入輸出、中間過程(經過G/D)的輸入輸出;
  4. 定義loss函數和優化器;
  5. 訓練和測試(套循環);
  6. 可視化

3.1 加載資料集、引入頭檔案

       MNIST資料集下載下傳到相應的位址,其加載方式是固定的。

# -*- coding: utf-8 -*-
##################################################################
#  1.引入頭檔案并加載mnist資料
##################################################################
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from scipy.stats import norm
import tensorflow.contrib.slim as slim

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/media/S318080208/py_pictures/minist/")  # ,one_hot=True)


tf.reset_default_graph()  # 用于清除預設圖形堆棧并重置全局預設圖形

           

3.2 定義G和D

  • 生成器G

    通過“兩個全連接配接+兩個反卷積(轉置卷積slim.conv2d_transpose)”模拟樣本的生成,每一層都有BN(批量歸一化)處理。

  • 判别器D

    判别器中有使用leaky_relu函數,其餘的在slim庫裡有,不用重新定義;

    判别器也是由“兩次卷積+兩次全連接配接”組成。生成的資料可以分别連接配接不同的輸出層産生不同的結果,其中1維的輸出層産生判别結果1或0,10維的輸出層産生分類結果,2維輸出層産生隐含次元資訊。

##################################################################
#  2.定義生成器與判别器
##################################################################
def generator(x):  # 生成器函數 : 兩個全連接配接+兩個反卷積模拟樣本的生成,每一層都有BN(批量歸一化)處理
    reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0   # 确認該變量作用域沒有變量
    # print (x.get_shape())
    with tf.variable_scope('generator', reuse=reuse):
        x = slim.fully_connected(x, 1024)
        # print(x)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        x = slim.fully_connected(x, 7*7*128)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        x = tf.reshape(x, [-1, 7, 7, 128])
        # print ('22', tf.tensor.get_shape())
        x = slim.conv2d_transpose(x, 64, kernel_size=[4, 4], stride=2, activation_fn = None)
        # print ('gen',x.get_shape())
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        z = slim.conv2d_transpose(x, 1, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.sigmoid)
        # print ('genz',z.get_shape())
    return z


def leaky_relu(x):
     return tf.where(tf.greater(x, 0), x, 0.01 * x)


def discriminator(x, num_classes=10, num_cont=2):  # 判别器函數 : 兩次卷積,再接兩次全連接配接
    reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0
    # print (reuse)
    # print (x.get_shape())
    with tf.variable_scope('discriminator', reuse=reuse):
        x = tf.reshape(x, shape=[-1, 28, 28, 1])
        x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
        x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
        # print ("conv2d",x.get_shape())
        x = slim.flatten(x)  # 輸入扁平化
        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=leaky_relu)
        recog_shared = slim.fully_connected(shared_tensor, num_outputs=128, activation_fn=leaky_relu)

        # 生成的資料可以分别連接配接不同的輸出層産生不同的結果
        # 1維的輸出層産生判别結果1或是0
        disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=None)
        disc = tf.squeeze(disc, -1)
        # print ("disc",disc.get_shape()) # 0 or 1

        # 10維的輸出層産生分類結果 (樣本标簽)
        recog_cat = slim.fully_connected(recog_shared, num_outputs=num_classes, activation_fn=None)

        # 2維輸出層産生重構造的隐含次元資訊
        recog_cont = slim.fully_connected(recog_shared, num_outputs=num_cont, activation_fn=tf.nn.sigmoid)
    return disc, recog_cat, recog_cont
           

3.3 定義網絡模型 輸入/輸出/中間參數

       輸入進生成器的是兩個噪聲資料(一般噪聲随機向量z_rand 38列 / 隐含資訊限制z_con 2列)和分類标簽labels的one_hot編碼 10 列。生成模拟樣本,然後将模拟樣本gen和真實樣本x分别輸入到判别器中,生成判别結果dis_fake/樣本标簽class_fake/重構造的隐含資訊con_fake 以及 dis_real/class_real/ _ 。

注:隐含資訊在這裡是指字型的粗細和傾斜資訊。它不由我們控制,比如我想讓字型擁有這兩個資訊的特征生成,就給他們兩個隐含資訊;如果沒有這種特征生成,就多加幾個隐含資訊,假如加10個隐含資訊,看裡面有沒有能控制的,多餘的就當是随機變量。如果再都沒有,就說明這個太複雜了,學習不了(個人了解)。
##################################################################
#  3.定義網絡模型 : 定義 參數/輸入/輸出/中間過程(經過G/D)的輸入輸出
##################################################################
batch_size = 10   # 擷取樣本的批次大小32
classes_dim = 10  # 10 classes
con_dim = 2       # 隐含資訊變量的次元, 應節點為z_con
rand_dim = 38     # 一般噪聲的次元, 應節點為z_rand, 二者都是符合标準高斯分布的随機數。
n_input = 784     # 28 * 28


x = tf.placeholder(tf.float32, [None, n_input])     # x為輸入真實圖檔images
y = tf.placeholder(tf.int32, [None])                # y為真實标簽labels

z_con = tf.random_normal((batch_size, con_dim))  # 2列
z_rand = tf.random_normal((batch_size, rand_dim))  # 38列
z = tf.concat(axis=1, values=[tf.one_hot(y, depth=classes_dim), z_con, z_rand])  # 50列 shape = (10, 50)
gen = generator(z)  # shape = (10, 28, 28, 1)
genout= tf.squeeze(gen, -1)  # shape = (10, 28, 28)


# labels for discriminator
y_real = tf.ones(batch_size)  # 真
y_fake = tf.zeros(batch_size)  # 假

# 判别器
disc_real, class_real, _ = discriminator(x)
disc_fake, class_fake, con_fake = discriminator(gen)
pred_class = tf.argmax(class_fake, dimension=1)
           

3.4 定義損失函數和優化器

       判别器D的損失函數有兩個:真實輸入的結果loss_d_r和模拟輸入的結果loss_d_f。二者結合為loss_d;(輸入真實樣本,判别為真/輸入模拟樣本,判别為假)

       生成器G的損失函數是想要“以假亂真”,自己輸出的模拟資料,讓它在D中判别為真,loss值為loss_g;

       還要定義網絡中共有的loss值:真實的标簽與輸入模拟樣本判别出的标簽loss_cf、真實的标簽與輸入真實樣本判别的标簽loss_cr、隐含資訊的重構誤差loss_con。

       之後用AdamOptimizer分别優化G和D。其中用了一個技巧,将D的學習率設小0.0001,将G的學習率設大0.001,可以讓G有更快的進化速度來模拟真實資料。

##################################################################
#  4.定義損失函數和優化器
##################################################################
# 判别器 loss
loss_d_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=y_real))
loss_d_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_fake))
loss_d = (loss_d_r + loss_d_f) / 2
# print ('loss_d', loss_d.get_shape())

# 生成器 loss
loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_real))

# categorical factor loss 分類因素損失
loss_cf = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_fake, labels=y))
loss_cr = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_real, labels=y))
loss_c = (loss_cf + loss_cr) / 2


# continuous factor loss 隐含資訊變量的損失
loss_con = tf.reduce_mean(tf.square(con_fake-z_con))

# 獲得各個網絡中各自的訓練參數清單
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'discriminator' in var.name]
g_vars = [var for var in t_vars if 'generator' in var.name]

# 優化器
disc_global_step = tf.Variable(0, trainable=False)
gen_global_step = tf.Variable(0, trainable=False)

train_disc = tf.train.AdamOptimizer(0.0001).minimize(loss_d + loss_c + loss_con, var_list=d_vars, global_step=disc_global_step)
train_gen = tf.train.AdamOptimizer(0.001).minimize(loss_g + loss_c + loss_con, var_list=g_vars, global_step=gen_global_step)
           

       所謂的AC-GAN就是将 loss_cr 加入到 loss_c 中。如果沒有 loss_cr,令 loss_c = loss_c,對于網絡生成模拟資料是不影響的,但是會損失真實分類與模拟資料間的對應關系(未告知分類資訊)(影響後果見可視化部分)。

3.5 訓練與測試

       建立 session,在循環裡使用 run 來運作前面建構的兩個優化器。測試部分分别使用 loss_d 和 loss_g 的 eval 完成。

       整個資料集運作3次後,判别誤差在0.5左右,基本可以認為是對真假資料無法分辨。

##################################################################
#  5.訓練與測試
#  建立session,循環中使用run來運作兩個優化器
##################################################################
training_epochs = 3
display_step = 1

config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.4

with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples/batch_size)  # 5500

        # 周遊全部資料集
        for i in range(total_batch):

            batch_xs, batch_ys = mnist.train.next_batch(batch_size)  # 取資料x:images, y:labels
            feeds = {x: batch_xs, y: batch_ys}

            # Fit training using batch data
            # 輸入資料,運作優化器
            l_disc, _, l_d_step = sess.run([loss_d, train_disc, disc_global_step], feeds)
            l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step], feeds)

        # 顯示訓練中的詳細資訊
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_disc), l_gen)

    print("完成!")
    
    # 測試
    print("Result: loss_d = ", loss_d.eval({x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size]}),
          "\n        loss_g = ", loss_g.eval({x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size]}))

           

測試結果如下:

GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)

3.6 可視化

       可視化部分分為兩部分,一部分是對原圖檔和對應的模拟資料圖檔進行plt。另一部分是利用隐含資訊生成的模拟樣本圖檔。

  • 第一部分
##################################################################
#  6.可視化
##################################################################
    # 根據圖檔模拟生成圖檔
    show_num = 10
    gensimple, d_class, inputx, inputy, con_out = sess.run(
        [genout, pred_class, x, y, con_fake], feed_dict={x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size]})

    f, a = plt.subplots(2, 10, figsize=(10, 2))  # figure 1000*20 , 分為10張子圖
    for i in range(show_num):
        a[0][i].imshow(np.reshape(inputx[i], (28, 28)))
        a[1][i].imshow(np.reshape(gensimple[i], (28, 28)))
        print("d_class", d_class[i], "inputy", inputy[i], "con_out", con_out[i])  # 輸出 判決預測種類/真實輸入種類/隐藏資訊
        
    plt.draw()
    plt.show()

    # 将隐含資訊分布對應的圖檔列印出來
    my_con = tf.placeholder(tf.float32, [batch_size, 2])
    myz = tf.concat(axis=1, values=[tf.one_hot(y, depth=classes_dim), my_con, z_rand])
    mygen = generator(myz)
    mygenout= tf.squeeze(mygen, -1) 
    
    my_con1 = np.ones([10, 2])
    a = np.linspace(0.0001, 0.99999, 10)
    y_input = np.ones([10])
    figure = np.zeros((28 * 10, 28 * 10))
    my_rand = tf.random_normal((10, rand_dim))
    for i in range(10):
        for j in range(10):
            my_con1[j][0] = a[i]
            my_con1[j][1] = a[j]
            y_input[j] = j
        mygenoutv = sess.run(mygenout, feed_dict={y: y_input, my_con: my_con1})
        for jj in range(10):
            digit = mygenoutv[jj].reshape(28, 28)
            figure[i * 28: (i + 1) * 28,
                   jj * 28: (jj + 1) * 28] = digit
    
    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='Greys_r')
    plt.show() 


           

得到的結果如下:

GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)
GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)
GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)

       可以看到前兩個結果是第一部分生成的,将原樣本與對應的模拟資料圖檔的分類、預測分類、隐含資訊列印出來;

       而最後一個結果是利用隐含資訊生成的模拟樣本圖檔,在整個【0,1】空間裡均勻抽樣,與樣本的标簽混合在一起,生成模拟資料。

       若去掉 loss_cf,隻保留 loss_cr 限制:(直接不優化模拟資料的分類資訊了,即我不努力了還不行麼)

GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)
GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)
GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)

若去掉loss_cr,隻保留loss_cf限制(沒告訴什麼是對的。即分類分對了,但與本身生成的模拟資料沒啥關系)

GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)
GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)
GAN生成對抗網絡合集(三):InfoGAN和ACGAN-指定類别生成模拟樣本的GAN(附代碼)

繼續閱讀