天天看點

生成對抗網絡GANs(筆記二)代碼

一、概述

生成對抗網絡(GAN,Generative Adversatial Networks)作為一種學習模型,是近年來無監督學習上最具前景的方法之一。被Yann Lecun 贊歎 GAN 是機器學習近十年來最有意思的想法。

生成器(Generative model):看做一個樣本生成器,它接受一個噪聲信号作為輸入,在經過一系列處理變成一個模拟樣本。

判别器(Discrimination model):看做一個二進制分類器,接受真實樣本x和模拟樣本x',将真實樣本作為機率1輸出,模拟樣本作為0輸出。

生成對抗網絡GANs(筆記二)代碼

訓練過程:生成器和判别器的博弈過程,過程和其它網絡的訓練差不多,都是向着梯度下降的方向優化代價函數,訓練結果是兩者達到納什均衡。這個時候生成器生成的模拟樣本和真實樣本已經看不出差異,是以經過判别器的時候,判别器無法判斷出他的輸入時真實樣本還是模拟樣本。

GAN是要幹什麼?

       嗯,假設我們有一堆資料,這些資料沒有标簽,比如我們有一堆的人臉圖檔,各種人臉,都不知道誰是誰,隻是有一堆的臉。然後我們想要通過這一堆資料生成新的資料(原始的論文做的工作),如上圖:目标是利用一個輸入的噪聲信号模拟得到一些人臉資料,這些生成的資料和原有資料很相似,人眼無法看出來差別。論文裡面Generative Adversarial Net講的假币的例子真是形象生動:一個造假币的團隊(生成器)和抓造假的警察(判别器),一開始造假币的團隊造假技術不過關,是以造出的假币總是被警察看穿,在這個過程中他們就要不停的提升自己的造假水準以避免被抓,而警察一開始可以很輕易的就判斷出假币。但是随着團隊造假技術的不斷提高,警察可能都判斷不出來是真币還是假币,是以警察也要在這個過程中不斷提高自己的判斷水準。最後造假團隊的技術上升到了一個高度,同時警察的判斷能力也達到了一定的高度,但是,造假團隊的假币警察都判斷不出來真假了。即0.5的幾率判斷出來是假币,0.5的幾率判斷出是真币。那麼我們訓練的目的就達到了。這就是一個雙人博弈!

看過論文的都應該認識下面的幾個公式:

生成對抗網絡GANs(筆記二)代碼
生成對抗網絡GANs(筆記二)代碼

詳細:設J(D)是一個判别網絡的目标函數——一個交叉熵(cross entropy)函數,J(D(x))左邊的部分D(x)表示判斷出x是真x的情況,右邊部分表示D判别出有生成網絡G把噪聲資料z給為造出來的情況。J(G)表示生成網絡的目标函數,他的目的是和D反着幹,是以在前面加了負号,類似一個Jensen-Shannon(JS)距離。

GAN的優化目标有兩個:優化判别器D和和優化生成器G,将第一個公式拆兩個部分,注意D()代表的是網絡判斷圖檔是否真實的機率:

1、優化判别器D的時候,我們希望D的鑒别能力可以達到最大,而log函數是一個單調遞增函數,是以指數最大就好了。是以D(x)最大而(1-D(G(z)))也要最大(即D(G(z))最小),等于是說D能判斷出來輸入是來自于生成模型G。優化G的時候,我們要讓G最小,看公式的第一項,沒有涉及G,所可看做常數項忽略就好,隻需要優化後面的部分,這時的G(z)應該是接近真實樣本的即D(G(z))最大。最小和最大自然的就産生了博弈,最終的結果是判别器的判别能力從1慢慢降到了0.5。這就是我們要找的均衡點(納什均衡)也就是J(D)的鞍點(saddle point)

有公式來解釋:

對于D:

生成對抗網絡GANs(筆記二)代碼

如圖,黑色點是真實資料data;綠色線是模型生成的僞資料model,是由映射過去的。藍色的線是我們要學習的D,它的目的是要把data和model的分布區分開,謝偉公式就是data和model分布相加做分母,分子是真是的data分布。最終的效果是D無線接近于1/2 = 0.5。也就是說Pdata和Pmodel無限相似,D再也無法辨識真僞資料的差別。最終的結果如下圖:

生成對抗網絡GANs(筆記二)代碼

但是一個問題就是:達到這樣的結果之後,生成模型就沒有辦法再學習了。因為1/2的導數永遠是0。

為了解決這個問題,除了把兩者對抗做成最小最大博弈,還可以把它寫成非飽和(Non-Saturating)博弈:

生成對抗網絡GANs(筆記二)代碼

也就是說用G自己的僞裝成功率來表示自己的目标函數(不再是直接拿J(D) 的負數)。這樣的話,我們的均衡就不再是由損失(loss)決定的了。J(D) 跟J(G) 沒有簡單粗暴的互相綁定,就算在D完美了以後,G還可以繼續被優化。

代碼:這裡的例子是模拟論文中生成的高斯分布。

位址:https://github.com/MrRenQIANG/GANs

關鍵代碼解析:

1、感覺機部分,将随機噪聲轉換成适合判别器輸入的次元。

生成對抗網絡GANs(筆記二)代碼

2、指數衰減的學習率,以及動量優化方法:

生成對抗網絡GANs(筆記二)代碼

3、使用均方誤差的判别器預訓練

生成對抗網絡GANs(筆記二)代碼

4、網絡結構以及整體的優化函數

生成對抗網絡GANs(筆記二)代碼

5、訓練和結果

生成對抗網絡GANs(筆記二)代碼

6、對抗過程中,D,G的變化趨勢

生成對抗網絡GANs(筆記二)代碼

Reference:

http://c.m.163.com/news/a/C7UE2MLT0511AQHO.html?spss=newsapp&spsw=1