“GAN ZOO”系列文章說明
GAN成為當下研究熱點,相關論文數量正在以指數趨勢增長,如上圖所示。
為了便于大家迅速追蹤研究熱點,“AI微刊”團隊持續推出“GAN ZOO”系列文章,精選典型GAN模型,對其進行精簡的解析,讓你“三分鐘”讀一篇論文。
GAN ZOO 第2節:
對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP
PS:本文知識點高度密集,建議碼起來,電腦端閱讀。
本文是“GAN ZOO”系列第2節,将為您:
- 分析原始GAN損失函數帶來梯度消失、模式崩潰等問題的原因;
- 介紹經典改進模型模型:LSGAN、WGAN、WGAN-GP。
4. 最小二乘GAN(LSGAN):LSGAN用最小二乘損失函數替換交叉熵損失函數,加大對離群樣本的懲罰
本論文首次發表于2016.11.13
4.1 原始GAN的缺陷
原始GAN使用Sigmoid交叉熵損失函數,容易造成梯度消失問題,使得生成器G的訓練不充分。
具體闡述為:
原始GAN的損失函數為Sigmoid交叉熵:
如下圖,藍線是判别器D的真假樣本決策邊界,藍線右下方的樣本判為真,左上方判為假。由于判别器被欺騙,是以将一部分真樣本(黃色圓圈o)判為假,将一部分假樣本(藍色十字+)判為真。
對于被判為“真”,但是又遠離真實樣本分布的假樣本(圖中粉色五角星☆雖然被判為真,但是離黃色圓圈o較遠),這些樣本被判别器D打上了“真”的标簽,即D(G(z))=1,是以在損失函數中表現為生成器G的損失函數值為0,如下所示:
此時,G的損失函數的導數趨近于0,更新梯度趨近于0,出現梯度消失,G不能再得到訓練。
簡單來說,有的生成樣本雖然成功欺騙了判别器D,但是其依然與真實樣本的分布相差較遠。Sigmoid交叉熵隻管真假、不管距離,不會再懲罰這種樣本,導緻生成器G出現梯度消失。
4.2 LSGAN的改進
4.2.1 LSGAN的思想
Sigmoid交叉熵适合用于邏輯分類,而最小二乘損失函數适合線性回歸。是以,為了迫使生成樣本盡可能地拟合真實樣本的分布,本論文采用最小二乘損失函數替代Sigmoid交叉熵,緩解梯度消失問題。
4.2.2 LSGAN的模型
(1)LSGAN的損失函數:
在判别器D的輸出層中去掉Sigmoid激活函數,并且在損失函數中去掉Log,使用最小二乘損失函數。使得D不僅判别真假,還懲罰離群的生成樣本(實際上,離決策面越遠的樣本對生成器更新梯度的貢獻越大),使生成樣本不斷向真實樣本分布靠近。如下圖所示:
4.2.3 LSGAN的缺點
LSGAN對離群樣本的懲罰機制要求所有的生成樣本分布,導緻樣本生成的”多樣性”降低, 生成的樣本很可能隻是對真實樣本的簡單”模仿”和細微改動。
4.3 LSGAN的實驗
作者将LSGAN用于手寫漢字資料庫(含3740個漢字),最終生成了可讀的漢字,從圖中可以看出。
參考
[1] “LSGAN:最小二乘生成對抗網絡”,機器之心
https://www.jiqizhixin.com/articles/2018-10-10-11;
5. Wasserstein GAN(WGAN):WGAN改善GAN的梯度消失問題、模式崩潰問題
本論文次發表于2017.1.26
5.1 原始GAN的缺陷
原始GAN的損失函數存在缺陷:當D訓練得越好,G的梯度消失越嚴重,限制了G的訓練。
具體闡述為:
簡單了解,原始GAN生成器損失函數為:
該公式通過一定變換後,可以用JS散度表示為:
根據以上公式可以推理出如下三點:
- 生成器G的目标就是通過梯度下降法減小Pg與Pr之間的JS散度,使生成樣本分布Pg逼近真實樣本分布Pr。
- 但是,如果Pg與Pr之間的重疊部分接近于0,那麼其JS散度就是常數log2,其梯度為0,無法使用梯度下降法進行學習。
- 并且,Pg與Pr之間的重疊部分為0的機率非常大【注解1】。此外,随着D判别Pg與Pr的能力增強,重疊部分将越來越小。
是以,原始GAN的不穩定表現為:如果D訓練得太好,G的loss趨近于常數,梯度為0,無法進行梯度下降;另一方面,如果D訓練得不好,G的梯度不穩,難以向Pr收斂【注解2】。
PS:在原始GAN後,WGAN前,Ian Goodfellow對G的損失函數進行了改進,改為:
但這個-logD(x)函數卻存“自相沖突”與“懲罰偏好”兩個問題【注解3】,導緻GAN訓練不穩定,并且容易出現模式崩潰。
【注解1】 Pg與Pr之間的重疊部分為0的機率較大
原因是:生成器G将低維噪聲Pz映射為高維樣本Pg(比如從100維映射為784維),784維的Pg的各種變化已經被100維的Pz限定死了,也就是Pg實際上是在784維空間定義了一個100維的資料分布(學術層面上來講就是,生成樣本的分布Pg實際上是高維空間中的低維流形)。然而另一方面,Pr本身就是高維的,也就是在784維空間定義了一個784維的資料分布。類比到三維空間,Pg是二維的面或者一維的線,而Pr充滿三維空間,Pg與Pr之間的重疊部分就隻會是一個面或者一條線,在三維空間中相當于“0” (高維空間中的低維流形與高維流形之間的重疊幾乎為“0”)。
是以,如果D接近最優,判别能力較強,也就是能夠完全将Pg與Pr分開,那麼Pg與Pr之間的JS散度就接近常數,求導為0。
【注解2】如果D訓練得不好,G的梯度不穩,難以向Pr收斂
Pz從低維空間映射到高維空間的映射方式有無數種,映射結果又無數種可能性,如果D的判别能力較弱,G就可能不受限制,向着不滿足要求的方向映射。
【注解3】-logD(x)函數的“自相沖突”與“懲罰偏好”問題
生成器損失函數為:
該函數可以被變換為:根據該公式可以看出損失函數具有以下兩個問題:
(1)自相沖突:
最小化改損失函數的時候就相當于最小化KL距離,同時最大化JS散度,自相沖突。
(2)懲罰偏好:
KL距離是非對稱的,導緻GAN對以下兩種錯誤的懲罰力度不同:
當Pg→0,Pr→1,即Pg的多樣性遠低于Pr,對 KL距離貢獻為0,懲罰微小;
當Pg→1,Pr→0,即Pg的多樣性遠高于Pr,對KL距離貢獻為無窮,懲罰巨大; 基于此, G更傾向于舍棄多樣性,而生成“重複且安全”樣本,帶來模式崩潰問題。
5.2 WGAN的改進
5.2.1 WGAN的思想
使用Wasserstein距離(Earth-Mover,EM距離【注解4】)替代原損失函數。
【注解4】Earth-Mover(推土機)距離
函數中,E(||x-y||)可以了解為将Pr這堆“沙土”挪到Pg“位置”所需的能量,而W(Pg,Pr)就是在“最優路徑”下最小的能量消耗。
Wasserstein距離相比KL散度、JS散度的優越性在于,即便兩個分布之間沒有重疊,Wasserstein距離仍然能夠反映它們的遠近,是以有連續的梯度。
5.2.2 LSGAN的模型
(1)WGAN的損失函數:
但是Wasserstein距離應用較難,需要進行變換,經過變換後WGAN的損失函數為(具體變換見論文原文):
該函數的意思就是要求搜尋所有Lipschitz常數【注解5】小于K的函數f,并取f在後面那一坨的上确界,并除以K。
由于函數f有很多中形式,是以選擇用神經網絡來拟合或者囊括盡可能多的f。
【注解5】Lipschitz常數
連續函數f如果在其定義域内的導數f’的絕對值|f’|滿足以下條件:
就稱這個函數是Lipschitz連續的,并且稱K為Lipschitz常數。
(2)WGAN的算法流程:
(3)WGAN的改進之處:
簡單來說WGAN相對于GAN的改變就是一下四點:
- 判别器最後一層去掉Sigmoid;
- 生成器和判别器的loss不取log;
- 每次得到D的參數更新值之後,将其剪切(Chip)到一個較小的區間[-c,c],使其滿足Lipschitz條件;
- 不要用基于動量的優化算法(包括momentum和Adam),推薦RMSProp,SGD也行
參考
[1] “令人拍案叫絕的Wasserstein GAN”,知乎
https://zhuanlan.zhihu.com/p/25071913;
6. WGAN-GP:WGAN-GP改善WGAN梯度剪切問題
本文次發表于2017.12.25
6.1 原始WGAN的缺陷
本文的三作Martin Ajorvsky是WGAN論文中的一作,本文是對WGAN的梯度剪切問題的改進。
原始WGAN為實作Lipschitz連續條件,将D的參數更新值剪切到較小的區間[-c,c],這使得參數在-c與c兩點處聚集,限制拟合能力,如圖:
下圖是随着判别器層數增大,梯度範數的Log值的變化曲線,可見WGAN的三條曲線都出現了梯度消失或者梯度爆炸,WGAN-GP則比較平穩。
6.2 WGAN-GP的改進
6.2.1 WGAN-GP的思想
直接剪切太過于武斷,那就換一種柔和的方式。WGAN-GP将直接剪切替換為一個懲罰項,通過懲罰項限制梯度的值。
6.2.2 WGAN-GP的模型
(1)WGAN-GP的損失函數:
WGAN-GP在原WGAN的損失函數後面添加了懲罰項:
懲罰項中的1本來是Lipschitz常數K,目的是使得D的梯度既滿足Lipschitz條件(導數梯度不超過K),同時也不會太小(太小則學習太慢)。論文中為了簡便,将K定義為1。
注意事項:
- 随機采樣:不需要對所有樣本都執行懲罰項,隻需要在真假樣本最容易混淆的區域每次随機采樣部分樣本,對其執行懲罰,這樣可以減小計算難度。
- 懲罰項因子:懲罰項因子λ需要調試,本論文中使用的都是λ=1。
- Batch Normalization:本文的梯度懲罰是對每個樣本單獨施加的,如果引入Batch Normalization會使得同個batch中不同樣本出現互相依賴的情況,是以建議不使用Batch Normalization,或者使用不會産生樣本依賴的Layer Normalization等。
(2)WGAN-GP的算法流程:
參考
[1] “WGAN-GP與WGAN及GAN的比較”,CSDN
https://blog.csdn.net/qq_38826019/article/details/80786061;
[2] “WGAN最新進展:從weight clipping到gradient penalty”,煉數成金
http://www.dataguru.cn/article-11229-1.html;
本文完
關注本公衆号“AI微刊”,背景發送“GAN ZOO”,即可獲得GAN ZOO系列論文包以及源代碼資源包。
微信号:AI微刊
背景發送“GAN ZOO ”,即可獲得GAN ZOO系列論文包以及源代碼資源包。