天天看點

GAN 很複雜?如何用不到 50 行代碼訓練 GAN(基于 PyTorch)

本文作者為前谷歌進階工程師、ai 初創公司 wavefront 創始人兼 cto dev nag,介紹了他是如何用不到五十行代碼,在 pytorch 平台上完成對 gan 的訓練。雷鋒網編譯整理。

GAN 很複雜?如何用不到 50 行代碼訓練 GAN(基于 PyTorch)

dev nag

在進入技術層面之前,為照顧新入門的開發者,雷鋒網先來介紹下什麼是 gan。

2014 年,ian goodfellow 和他在蒙特利爾大學的同僚發表了一篇震撼學界的論文。沒錯,我說的就是《generative adversarial nets》,這标志着生成對抗網絡(gan)的誕生,而這是通過對計算圖和博弈論的創新性結合。他們的研究展示,給定充分的模組化能力,兩個博弈模型能夠通過簡單的反向傳播(backpropagation)來協同訓練。

這兩個模型的角色定位十分鮮明。給定真實資料集 r,g 是生成器(generator),它的任務是生成能以假亂真的假資料;而 d 是判别器 (discriminator),它從真實資料集或者 g 那裡擷取資料, 然後做出判别真假的标記。ian goodfellow 的比喻是,g 就像一個赝品作坊,想要讓做出來的東西盡可能接近真品,蒙混過關。而 d 就是文物鑒定專家,要能區分出真品和高仿(但在這個例子中,造假者 g 看不到原始資料,而隻有 d 的鑒定結果——前者是在盲幹)。

GAN 很複雜?如何用不到 50 行代碼訓練 GAN(基于 PyTorch)

理想情況下,d 和 g 都會随着不斷訓練,做得越來越好——直到 g 基本上成為了一個“赝品制造大師”,而 d 因無法正确區分兩種資料分布輸給 g。

實踐中,ian goodfellow 展示的這項技術在本質上是:g 能夠對原始資料集進行一種無監督學習,找到以更低次元的方式(lower-dimensional manner)來表示資料的某種方法。而無監督學習之是以重要,就好像雷鋒網反複引用的 yann lecun 的那句話:“無監督學習是蛋糕的糕體”。這句話中的蛋糕,指的是無數學者、開發者苦苦追尋的“真正的 ai”。

dev nag:在表面上,gan 這門如此強大、複雜的技術,看起來需要編寫天量的代碼來執行,但事實未必如此。我們使用 pytorch,能夠在 50 行代碼以内建立出簡單的 gan 模型。這之中,其實隻有五個部分需要考慮:

r:原始、真實資料集

i:作為熵的一項來源,進入生成器的随機噪音

g:生成器,試圖模仿原始資料

d:判别器,試圖差別 g 的生成資料和 r

我們教 g 糊弄 d、教 d 當心 g 的“訓練”環。

1.) r:在我們的例子裡,從最簡單的 r 着手——貝爾曲線(bell curve)。它把平均數(mean)和标準差(standard deviation)作為輸入,然後輸出能提供樣本資料正确圖形(從 gaussian 用這些參數獲得 )的函數。在我們的代碼例子中,我們使用 4 的平均數和 1.25 的标準差。

GAN 很複雜?如何用不到 50 行代碼訓練 GAN(基于 PyTorch)

2.) i:生成器的輸入是随機的,為提高點難度,我們使用均勻分布(uniform distribution )而非标準分布。這意味着,我們的 model g 不能簡單地改變輸入(放大/縮小、平移)來複制 r,而需要用非線性的方式來改造資料。

GAN 很複雜?如何用不到 50 行代碼訓練 GAN(基于 PyTorch)

3.) g: 該生成器是個标準的前饋圖(feedforward graph)——兩層隐層,三個線性映射(linear maps)。我們使用了 elu (exponential linear unit)。g 将從 i 獲得平均分布的資料樣本,然後找到某種方式來模仿 r 中标準分布的樣本。

GAN 很複雜?如何用不到 50 行代碼訓練 GAN(基于 PyTorch)

4.) d: 判别器的代碼和 g 的生成器代碼很接近。一個有兩層隐層和三個線性映射的前饋圖。它會從 r 或 g 那裡獲得樣本,然後輸出 0 或 1 的判别值,對應反例和正例。這幾乎是神經網絡的最弱版本了。

GAN 很複雜?如何用不到 50 行代碼訓練 GAN(基于 PyTorch)

5.) 最後,訓練環在兩個模式中變幻:第一步,用被準确标記的真實資料 vs. 假資料訓練 d;随後,訓練 g 來騙過 d,這裡是用的不準确标記。道友們,這是正邪之間的較量。

GAN 很複雜?如何用不到 50 行代碼訓練 GAN(基于 PyTorch)

即便你從沒接觸過 pytorch,大概也能明白發生了什麼。在第一部分(綠色),我們讓兩種類型的資料經過 d,并對 d 的猜測 vs. 真實标記執行不同的評判标準。這是 “forward” 那一步;随後我們需要 “backward()” 來計算梯度,然後把這用來在 d_optimizer step() 中更新 d 的參數。這裡,g 被使用但尚未被訓練。

在最後的部分(紅色),我們對 g 執行同樣的操作——注意我們要讓 g 的輸出穿過 d (這其實是送給造假者一個鑒定專家來練手)。但在這一步,我們并不優化、或者改變 d。我們不想讓鑒定者 d 學習到錯誤的标記。是以,我們隻執行 g_optimizer.step()。

這就完成了。據雷鋒網(公衆号:雷鋒網)了解,還有一些其他的樣闆代碼,但是對于 gan 來說隻需要這五個部分,沒有其他的了。

在 d 和 g 之間幾千輪交手之後,我們會得到什麼?判别器 d 會快速改進,而 g 的進展要緩慢許多。但當模型達到一定性能之後,g 才有了個配得上的對手,并開始提升,巨幅提升。

兩萬輪訓練之後,g 的輸入平均值超過 4,但會傳回到相當平穩、合理的範圍(左圖)。同樣的,标準差一開始在錯誤的方向降低,但随後攀升至理想中的 1.25 區間(右圖),達到 r 的層次。

GAN 很複雜?如何用不到 50 行代碼訓練 GAN(基于 PyTorch)

是以,基礎資料最終會與 r 吻合。那麼,那些比 r 更高的時候呢?資料分布的形狀看起來合理嗎?畢竟,你一定可以得到有 4.0 的平均值和 1.25 标準內插補點的均勻分布,但那不會真的符合 r。我們一起來看看 g 生成的最終分布。

GAN 很複雜?如何用不到 50 行代碼訓練 GAN(基于 PyTorch)

結果是不錯的。左側的尾巴比右側長一些,但偏離程度和峰值與原始 gaussian 十分相近。g 接近完美地再現了原始分布 r——d 落于下風,無法分辨真相和假相。而這就是我們想要得到的結果——使用不到 50 行代碼。

該說的都說完了,老司機請上 github 把玩全套代碼。

本文作者:三川

繼續閱讀