本文作者為前谷歌進階工程師、ai 初創公司 wavefront 創始人兼 cto dev nag,介紹了他是如何用不到五十行代碼,在 pytorch 平台上完成對 gan 的訓練。雷鋒網編譯整理。
dev nag
在進入技術層面之前,為照顧新入門的開發者,雷鋒網先來介紹下什麼是 gan。
2014 年,ian goodfellow 和他在蒙特利爾大學的同僚發表了一篇震撼學界的論文。沒錯,我說的就是《generative adversarial nets》,這标志着生成對抗網絡(gan)的誕生,而這是通過對計算圖和博弈論的創新性結合。他們的研究展示,給定充分的模組化能力,兩個博弈模型能夠通過簡單的反向傳播(backpropagation)來協同訓練。
這兩個模型的角色定位十分鮮明。給定真實資料集 r,g 是生成器(generator),它的任務是生成能以假亂真的假資料;而 d 是判别器 (discriminator),它從真實資料集或者 g 那裡擷取資料, 然後做出判别真假的标記。ian goodfellow 的比喻是,g 就像一個赝品作坊,想要讓做出來的東西盡可能接近真品,蒙混過關。而 d 就是文物鑒定專家,要能區分出真品和高仿(但在這個例子中,造假者 g 看不到原始資料,而隻有 d 的鑒定結果——前者是在盲幹)。
理想情況下,d 和 g 都會随着不斷訓練,做得越來越好——直到 g 基本上成為了一個“赝品制造大師”,而 d 因無法正确區分兩種資料分布輸給 g。
實踐中,ian goodfellow 展示的這項技術在本質上是:g 能夠對原始資料集進行一種無監督學習,找到以更低次元的方式(lower-dimensional manner)來表示資料的某種方法。而無監督學習之是以重要,就好像雷鋒網反複引用的 yann lecun 的那句話:“無監督學習是蛋糕的糕體”。這句話中的蛋糕,指的是無數學者、開發者苦苦追尋的“真正的 ai”。
dev nag:在表面上,gan 這門如此強大、複雜的技術,看起來需要編寫天量的代碼來執行,但事實未必如此。我們使用 pytorch,能夠在 50 行代碼以内建立出簡單的 gan 模型。這之中,其實隻有五個部分需要考慮:
r:原始、真實資料集
i:作為熵的一項來源,進入生成器的随機噪音
g:生成器,試圖模仿原始資料
d:判别器,試圖差別 g 的生成資料和 r
我們教 g 糊弄 d、教 d 當心 g 的“訓練”環。
1.) r:在我們的例子裡,從最簡單的 r 着手——貝爾曲線(bell curve)。它把平均數(mean)和标準差(standard deviation)作為輸入,然後輸出能提供樣本資料正确圖形(從 gaussian 用這些參數獲得 )的函數。在我們的代碼例子中,我們使用 4 的平均數和 1.25 的标準差。
2.) i:生成器的輸入是随機的,為提高點難度,我們使用均勻分布(uniform distribution )而非标準分布。這意味着,我們的 model g 不能簡單地改變輸入(放大/縮小、平移)來複制 r,而需要用非線性的方式來改造資料。
3.) g: 該生成器是個标準的前饋圖(feedforward graph)——兩層隐層,三個線性映射(linear maps)。我們使用了 elu (exponential linear unit)。g 将從 i 獲得平均分布的資料樣本,然後找到某種方式來模仿 r 中标準分布的樣本。
4.) d: 判别器的代碼和 g 的生成器代碼很接近。一個有兩層隐層和三個線性映射的前饋圖。它會從 r 或 g 那裡獲得樣本,然後輸出 0 或 1 的判别值,對應反例和正例。這幾乎是神經網絡的最弱版本了。
5.) 最後,訓練環在兩個模式中變幻:第一步,用被準确标記的真實資料 vs. 假資料訓練 d;随後,訓練 g 來騙過 d,這裡是用的不準确标記。道友們,這是正邪之間的較量。
即便你從沒接觸過 pytorch,大概也能明白發生了什麼。在第一部分(綠色),我們讓兩種類型的資料經過 d,并對 d 的猜測 vs. 真實标記執行不同的評判标準。這是 “forward” 那一步;随後我們需要 “backward()” 來計算梯度,然後把這用來在 d_optimizer step() 中更新 d 的參數。這裡,g 被使用但尚未被訓練。
在最後的部分(紅色),我們對 g 執行同樣的操作——注意我們要讓 g 的輸出穿過 d (這其實是送給造假者一個鑒定專家來練手)。但在這一步,我們并不優化、或者改變 d。我們不想讓鑒定者 d 學習到錯誤的标記。是以,我們隻執行 g_optimizer.step()。
這就完成了。據雷鋒網(公衆号:雷鋒網)了解,還有一些其他的樣闆代碼,但是對于 gan 來說隻需要這五個部分,沒有其他的了。
在 d 和 g 之間幾千輪交手之後,我們會得到什麼?判别器 d 會快速改進,而 g 的進展要緩慢許多。但當模型達到一定性能之後,g 才有了個配得上的對手,并開始提升,巨幅提升。
兩萬輪訓練之後,g 的輸入平均值超過 4,但會傳回到相當平穩、合理的範圍(左圖)。同樣的,标準差一開始在錯誤的方向降低,但随後攀升至理想中的 1.25 區間(右圖),達到 r 的層次。
是以,基礎資料最終會與 r 吻合。那麼,那些比 r 更高的時候呢?資料分布的形狀看起來合理嗎?畢竟,你一定可以得到有 4.0 的平均值和 1.25 标準內插補點的均勻分布,但那不會真的符合 r。我們一起來看看 g 生成的最終分布。
結果是不錯的。左側的尾巴比右側長一些,但偏離程度和峰值與原始 gaussian 十分相近。g 接近完美地再現了原始分布 r——d 落于下風,無法分辨真相和假相。而這就是我們想要得到的結果——使用不到 50 行代碼。
該說的都說完了,老司機請上 github 把玩全套代碼。
本文作者:三川