天天看點

【GANs學習筆記】(一)初步了解GANs第一章 初步了解GANs

完整筆記:http://www.gwylab.com/note-gans.html

——————————————————————-

本章借鑒内容:

https://blog.csdn.net/sxf1061926959/article/details/54630462

http://www.gwylab.com/paper-gans.html

第一章 初步了解GANs

1. 生成模型與判别模型

      了解對抗網絡,首先要了解生成模型和判别模型。判别模型比較好了解,就像分類一樣,有一個判别界限,通過這個判别界限去區分樣本。從機率角度分析就是獲得樣本x屬于類别y的機率,是一個條件機率P(y|x)。而生成模型是需要在整個條件内去産生資料的分布,就像高斯分布一樣,需要去拟合整個分布,從機率角度分析就是樣本x在整個分布中的産生的機率,即聯合機率P(xy)。具體可以參考博文:

      http://blog.csdn.net/zouxy09/article/details/8195017

2. 對抗網絡思想

       了解了生成模型和判别模型後,再來了解對抗網絡就很直接了,對抗網絡隻是提出了一種網絡結構,總體來說, GANs簡單的想法就是用兩個模型,一個生成模型,一個判别模型。判别模型用于判斷一個給定的圖檔是不是真實的圖檔(從資料集裡擷取的圖檔),生成模型的任務是去創造一個看起來像真的圖檔一樣的圖檔。而在開始的時候這兩個模型都是沒有經過訓練的,這兩個模型一起對抗訓練,生成模型産生一張圖檔去欺騙判别模型,然後判别模型去判斷這張圖檔是真是假,最終在這兩個模型訓練的過程中,兩個模型的能力越來越強,最終達到穩态。(本書僅介紹GANs在計算機視覺方面的應用,但是GANs的用途很廣,不單單是圖像,其他方面,譬如文本、語音,或者任何隻要含有規律的資料合成,都能用GANs實作。)

3. 詳細實作過程

      假設我們現在的資料集是手寫體數字的資料集minst,生成模型的輸入可以是二維高斯模型中一個随機的向量,生成模型的輸出是一張僞造的fake image,同時通過索引擷取資料集中的真實手寫數字圖檔real image,然後将fake image和real image一同傳給判别模型,由判别模型給出real還是fake的判别結果。于是,一個簡單的GANs模型就搭建好了。

      值得注意的是,生成模型G和判别模型D可以是各種各樣的神經網絡,對抗網絡的生成模型和判别模型沒有任何限制。

【GANs學習筆記】(一)初步了解GANs第一章 初步了解GANs

3.1 前向傳播階段

1. 模型輸入

      1

我們随機産生一個随機向量作為生成模型的資料,然後經過生成模型後産生一個新的向量,作為Fake Image,記作D(z)。

      2

從資料集中随機選擇一張圖檔,将圖檔轉化成向量,作為Real Image,記作x。

2. 模型輸出

      将由1或者2産生的輸出,作為判别網絡的輸入,經過判别網絡後輸出值為一個0到1之間的數,用于表示輸入圖檔為Real Image的機率,real為1,fake為0。

      使用得到的機率值計算損失函數,解釋損失函數之前,我們先解釋下判别模型的輸入。根據輸入的圖檔類型是Fake Image或Real Image将判别模型的輸入資料的label标記為0或者1。即判别模型的輸入類型為(

xfake

,0)或者(

xreal

,1)。

3.2 反向傳播階段

1. 優化目标

      原文給了這麼一個優化函數:

【GANs學習筆記】(一)初步了解GANs第一章 初步了解GANs

      我們來了解一下這個目标公式,先優化D,再優化G,拆解之後即為如下兩步:

      第一步

:優化D

【GANs學習筆記】(一)初步了解GANs第一章 初步了解GANs

      優化D,即優化判别網絡時,沒有生成網絡什麼事,後面的G(z)就相當于已經得到的假樣本。優化D的公式的第一項,使得真樣本x輸入的時候,得到的結果越大越好,因為真樣本的預測結果越接近1越好;對于假樣本G(z),需要優化的是其結果越小越好,也就是D(G(z))越小越好,因為它的标簽為0。但是第一項越大,第二項越小,就沖突了,是以把第二項改為1-D(G(z)),這樣就是越大越好。

      第二步

:優化G

【GANs學習筆記】(一)初步了解GANs第一章 初步了解GANs

      在優化G的時候,這個時候沒有真樣本什麼事,是以把第一項直接去掉,這時候隻有假樣本,但是這個時候希望假樣本的标簽是1,是以是D(G(z))越大越好,但是為了統一成1-D(G(z))的形式,那麼隻能是最小化1-D(G(z)),本質上沒有差別,隻是為了形式的統一。之後這兩個優化模型可以合并起來寫,就變成最開始的最大最小目标函數了。

      我們依據上面的優化目标函數,便能得到如下模型最終的損失函數。

2. 判别模型的損失函數

【GANs學習筆記】(一)初步了解GANs第一章 初步了解GANs

      當輸入的是從資料集中取出的real Iamge 資料時,我們隻需要考慮第二部分,D(x)為判别模型的輸出,表示輸入x為real 資料的機率,我們的目的是讓判别模型的輸出D(x)的輸出盡量靠近1。

      當輸入的為fake資料時,我們隻計算第一部分,G(z)是生成模型的輸出,輸出的是一張Fake Image。我們要做的是讓D(G(z))的輸出盡可能趨向于0。這樣才能表示判别模型是有區分力的。

      相對判别模型來說,這個損失函數其實就是交叉熵損失函數。計算loss,進行梯度反傳。這裡的梯度反傳可以使用任何一種梯度修正的方法。

      當更新完判别模型的參數後,我們再去更新生成模型的參數。

3. 生成模型的損失函數

【GANs學習筆記】(一)初步了解GANs第一章 初步了解GANs

      對于生成模型來說,我們要做的是讓G(z)産生的資料盡可能的和資料集中的資料一樣。就是所謂的同樣的資料分布。那麼我們要做的就是最小化生成模型的誤差,即隻将由G(z)産生的誤差傳給生成模型。

      但是針對判别模型的預測結果,要對梯度變化的方向進行改變。當判别模型認為G(z)輸出為真實資料集的時候和認為輸出為噪聲資料的時候,梯度更新方向要進行改變。

      即最終的損失函數為:
【GANs學習筆記】(一)初步了解GANs第一章 初步了解GANs

      其中

【GANs學習筆記】(一)初步了解GANs第一章 初步了解GANs

表示判别模型的預測類别,對預測機率取整,為0或者1.用于更改梯度方向,門檻值可以自己設定,或者正常的話就是0.5。

4. 反向傳播

      我們已經得到了生成模型和判别模型的損失函數,這樣分開看其實就是兩個單獨的模型,針對不同的模型可以按照自己的需要去是實作不同的誤差修正,我們也可以選擇最常用的BP做為誤差修正算法,更新模型參數。

      其實說了這麼多,生成對抗網絡的生成模型和判别模型是沒有任何限制,生成對抗網絡提出的隻是一種網絡結構,我們可以使用任何的生成模型和判别模型去實作一個生成對抗網絡。當得到損失函數後就安裝單個模型的更新方法進行修正即可。