天天看點

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

變分自編碼器(VAE)與生成對抗網絡(GAN)是複雜分布上無監督學習最具前景的兩類方法。

本項目總結了使用變分自編碼器(Variational Autoencode,VAE)和生成對抗網絡(GAN)對給定資料分布進行模組化,并且對比了這些模型的性能。你可能會問:我們已經有了數百萬張圖像,為什麼還要從給定資料分布中生成圖像呢?正如 Ian Goodfellow 在 NIPS 2016 教程中指出的那樣,實際上有很多應用。我覺得比較有趣的一種是使用 GAN 模拟可能的未來,就像強化學習中使用政策梯度的智能體那樣。

本文組織架構:

變分自編碼器(VAE)

生成對抗網絡(GAN)

訓練普通 GAN 的難點

訓練細節

在 MNIST 上進行 VAE 和 GAN 對比實驗

在無标簽的情況下訓練 GAN 判别器

在有标簽的情況下訓練 GAN 判别器

在 CIFAR 上進行 VAE 和 GAN 實驗

延伸閱讀

VAE

變分自編碼器可用于對先驗資料分布進行模組化。從名字上就可以看出,它包括兩部分:編碼器和解碼器。編碼器将資料分布的進階表征映射到資料的低級表征,低級表征叫作本征向量(latent vector)。解碼器吸收資料的低級表征,然後輸出同樣資料的進階表征。

從數學上來講,讓 X 作為編碼器的輸入,z 作為本征向量,X′作為解碼器的輸出。

圖 1 是 VAE 的可視化圖。

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

這與标準自編碼器有何不同?關鍵差別在于我們對本征向量的限制。如果是标準自編碼器,那麼我們主要關注重建損失(reconstruction loss),即:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

而在變分自編碼器的情況中,我們希望本征向量遵循特定的分布,通常是機關高斯分布(unit Gaussian distribution),使下列損失得到優化:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

p(z′)∼N(0,I) 中 I 指機關矩陣(identity matrx),q(z∣X) 是本征向量的分布,其中。和由神經網絡來計算。KL(A,B) 是分布 B 到 A 的 KL 散度。

由于損失函數中還有其他項,是以存在模型生成圖像的精度,同本征向量的分布與機關高斯分布的接近程度之間存在權衡(trade-off)。這兩部分由兩個超參數λ_1 和λ_2 來控制。

GAN

GAN 是根據給定的先驗分布生成資料的另一種方式,包括同時進行的兩部分:判别器和生成器。

判别器用于對「真」圖像和「僞」圖像進行分類,生成器從随機噪聲中生成圖像(随機噪聲通常叫作本征向量或代碼,該噪聲通常從均勻分布(uniform distribution)或高斯分布中擷取)。生成器的任務是生成可以以假亂真的圖像,令判别器也無法區分出來。也就是說,生成器和判别器是互相對抗的。判别器非常努力地嘗試區分真僞圖像,同時生成器盡力生成更加逼真的圖像,目的是使判别器将這些圖像也分類為「真」圖像。

圖 2 是 GAN 的典型結構。

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

生成器包括利用代碼輸出圖像的解卷積層。圖 3 是生成器的架構圖。

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

訓練 GAN 的難點

訓練 GAN 時我們會遇到一些挑戰,我認為其中最大的挑戰在于本征向量/代碼的采樣。代碼隻是從先驗分布中對本征變量的噪聲采樣。有很多種方法可以克服該挑戰,包括:使用 VAE 對本征變量進行編碼,學習資料的先驗分布。這聽起來要好一些,因為編碼器能夠學習資料分布,現在我們可以從分布中進行采樣,而不是生成随機噪聲。

我們知道兩個分布 p(真實分布)和 q(估計分布)之間的交叉熵通過以下公式計算:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

對于二進制分類:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

對于 GAN,我們假設分布的一半來自真實資料分布,一半來自估計分布,是以:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

訓練 GAN 需要同時優化兩個損失函數。

按照極小極大值算法:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

這裡,判别器需要區分圖像的真僞,不管圖像是否包含真實物體,都沒有注意力。當我們在 CIFAR 上檢查 GAN 生成的圖像時會明顯看到這一點。

我們可以重新定義判别器損失目标,使之包含标簽。這被證明可以提高主觀樣本的品質。如:在 MNIST 或 CIFAR-10(兩個資料集都有 10 個類别)。

上述 Python 損失函數在 TensorFlow 中的實作:

在 MNIST 上進行 VAE 與 GAN 對比實驗

1. 不使用标簽訓練判别器

實驗使用了 MNIST 的 28×28 圖像,下圖中:

左側:資料分布的 64 張原始圖像

中間:VAE 生成的 64 張圖像

右側:GAN 生成的 64 張圖像

第 1 次疊代:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

第 2 次疊代:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

第 3 次疊代:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

第 4 次疊代:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

第 100 次疊代:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

VAE(125)和 GAN(368)訓練的最終結果:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

根據GAN疊代次數生成的gif圖:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

顯然,VAE 生成的圖像與 GAN 生成的圖像相比,前者更加模糊。這個結果在預料之中,因為 VAE 模型生成的所有輸出都是分布平均。為了減少圖像的模糊度,我們可以使用 L1 損失來代替 L2 損失。

在第一個實驗後,作者還将在近期研究使用标簽訓練判别器,并在 CIFAR 資料集上測試 VAE 與 GAN 的性能。

使用

下載下傳 MNIST 和 CIFAR 資料集

使用 MNIST 訓練 VAE 請運作:

使用 MNIST 訓練 GAN 請運作:

想要擷取完整的指令行選項,請運作:

該模型由 generate_frq 決定生成圖檔的頻率,預設值為 1。

GAN 在 MNIST 上的訓練結果

MNIST 資料集中的樣本圖像:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

上方是 VAE 生成的圖像,下方的圖展示了 GAN 生成圖像的過程:

在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

原文釋出時間為:2017-10-29

繼續閱讀