天天看點

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

“GAN ZOO”系列文章說明

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

    GAN成為當下研究熱點,相關論文數量正在以指數趨勢增長,如上圖所示。

    為了便于大家迅速追蹤研究熱點,“AI微刊”團隊持續推出“GAN ZOO”系列文章,精選典型GAN模型,對其進行精簡的解析,讓你“三分鐘”讀一篇論文。

GAN ZOO 第2節:

對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

PS:本文知識點高度密集,建議碼起來,電腦端閱讀。

本文是“GAN ZOO”系列第2節,将為您:

  • 分析原始GAN損失函數帶來梯度消失、模式崩潰等問題的原因;
  • 介紹經典改進模型模型:LSGAN、WGAN、WGAN-GP。

4. 最小二乘GAN(LSGAN):LSGAN用最小二乘損失函數替換交叉熵損失函數,加大對離群樣本的懲罰

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

本論文首次發表于2016.11.13

4.1 原始GAN的缺陷

    原始GAN使用Sigmoid交叉熵損失函數,容易造成梯度消失問題,使得生成器G的訓練不充分。

    具體闡述為:

    原始GAN的損失函數為Sigmoid交叉熵:

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

    如下圖,藍線是判别器D的真假樣本決策邊界,藍線右下方的樣本判為真,左上方判為假。由于判别器被欺騙,是以将一部分真樣本(黃色圓圈o)判為假,将一部分假樣本(藍色十字+)判為真。

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

    對于被判為“真”,但是又遠離真實樣本分布的假樣本(圖中粉色五角星☆雖然被判為真,但是離黃色圓圈o較遠),這些樣本被判别器D打上了“真”的标簽,即D(G(z))=1,是以在損失函數中表現為生成器G的損失函數值為0,如下所示:

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

    此時,G的損失函數的導數趨近于0,更新梯度趨近于0,出現梯度消失,G不能再得到訓練。

    簡單來說,有的生成樣本雖然成功欺騙了判别器D,但是其依然與真實樣本的分布相差較遠。Sigmoid交叉熵隻管真假、不管距離,不會再懲罰這種樣本,導緻生成器G出現梯度消失。

4.2 LSGAN的改進

4.2.1 LSGAN的思想

    Sigmoid交叉熵适合用于邏輯分類,而最小二乘損失函數适合線性回歸。是以,為了迫使生成樣本盡可能地拟合真實樣本的分布,本論文采用最小二乘損失函數替代Sigmoid交叉熵,緩解梯度消失問題。

4.2.2 LSGAN的模型

(1)LSGAN的損失函數:

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

    在判别器D的輸出層中去掉Sigmoid激活函數,并且在損失函數中去掉Log,使用最小二乘損失函數。使得D不僅判别真假,還懲罰離群的生成樣本(實際上,離決策面越遠的樣本對生成器更新梯度的貢獻越大),使生成樣本不斷向真實樣本分布靠近。如下圖所示:

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

4.2.3 LSGAN的缺點

    LSGAN對離群樣本的懲罰機制要求所有的生成樣本分布,導緻樣本生成的”多樣性”降低, 生成的樣本很可能隻是對真實樣本的簡單”模仿”和細微改動。

4.3 LSGAN的實驗

    作者将LSGAN用于手寫漢字資料庫(含3740個漢字),最終生成了可讀的漢字,從圖中可以看出。

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

參考

[1] “LSGAN:最小二乘生成對抗網絡”,機器之心

https://www.jiqizhixin.com/articles/2018-10-10-11;

5. Wasserstein GAN(WGAN):WGAN改善GAN的梯度消失問題、模式崩潰問題

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

本論文次發表于2017.1.26

5.1 原始GAN的缺陷

    原始GAN的損失函數存在缺陷:當D訓練得越好,G的梯度消失越嚴重,限制了G的訓練。

    具體闡述為:

    簡單了解,原始GAN生成器損失函數為:

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

    該公式通過一定變換後,可以用JS散度表示為:

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

    根據以上公式可以推理出如下三點:

  1. 生成器G的目标就是通過梯度下降法減小Pg與Pr之間的JS散度,使生成樣本分布Pg逼近真實樣本分布Pr。
  2. 但是,如果Pg與Pr之間的重疊部分接近于0,那麼其JS散度就是常數log2,其梯度為0,無法使用梯度下降法進行學習。
  3. 并且,Pg與Pr之間的重疊部分為0的機率非常大【注解1】。此外,随着D判别Pg與Pr的能力增強,重疊部分将越來越小。

    是以,原始GAN的不穩定表現為:如果D訓練得太好,G的loss趨近于常數,梯度為0,無法進行梯度下降;另一方面,如果D訓練得不好,G的梯度不穩,難以向Pr收斂【注解2】。

    PS:在原始GAN後,WGAN前,Ian Goodfellow對G的損失函數進行了改進,改為:

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

    但這個-logD(x)函數卻存“自相沖突”與“懲罰偏好”兩個問題【注解3】,導緻GAN訓練不穩定,并且容易出現模式崩潰。

【注解1】 Pg與Pr之間的重疊部分為0的機率較大

    原因是:生成器G将低維噪聲Pz映射為高維樣本Pg(比如從100維映射為784維),784維的Pg的各種變化已經被100維的Pz限定死了,也就是Pg實際上是在784維空間定義了一個100維的資料分布(學術層面上來講就是,生成樣本的分布Pg實際上是高維空間中的低維流形)。然而另一方面,Pr本身就是高維的,也就是在784維空間定義了一個784維的資料分布。類比到三維空間,Pg是二維的面或者一維的線,而Pr充滿三維空間,Pg與Pr之間的重疊部分就隻會是一個面或者一條線,在三維空間中相當于“0” (高維空間中的低維流形與高維流形之間的重疊幾乎為“0”)。

       是以,如果D接近最優,判别能力較強,也就是能夠完全将Pg與Pr分開,那麼Pg與Pr之間的JS散度就接近常數,求導為0。

【注解2】如果D訓練得不好,G的梯度不穩,難以向Pr收斂

    Pz從低維空間映射到高維空間的映射方式有無數種,映射結果又無數種可能性,如果D的判别能力較弱,G就可能不受限制,向着不滿足要求的方向映射。

【注解3】-logD(x)函數的“自相沖突”與“懲罰偏好”問題

    生成器損失函數為:

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP
    該函數可以被變換為:
GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

    根據該公式可以看出損失函數具有以下兩個問題:

(1)自相沖突:

       最小化改損失函數的時候就相當于最小化KL距離,同時最大化JS散度,自相沖突。

(2)懲罰偏好:

    KL距離是非對稱的,導緻GAN對以下兩種錯誤的懲罰力度不同:

    當Pg→0,Pr→1,即Pg的多樣性遠低于Pr,對 KL距離貢獻為0,懲罰微小;

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP
    當Pg→1,Pr→0,即Pg的多樣性遠高于Pr,對KL距離貢獻為無窮,懲罰巨大;
GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP
    基于此, G更傾向于舍棄多樣性,而生成“重複且安全”樣本,帶來模式崩潰問題。

5.2 WGAN的改進

5.2.1 WGAN的思想

    使用Wasserstein距離(Earth-Mover,EM距離【注解4】)替代原損失函數。

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

【注解4】Earth-Mover(推土機)距離

    函數中,E(||x-y||)可以了解為将Pr這堆“沙土”挪到Pg“位置”所需的能量,而W(Pg,Pr)就是在“最優路徑”下最小的能量消耗。

    Wasserstein距離相比KL散度、JS散度的優越性在于,即便兩個分布之間沒有重疊,Wasserstein距離仍然能夠反映它們的遠近,是以有連續的梯度。

5.2.2 LSGAN的模型

(1)WGAN的損失函數:

    但是Wasserstein距離應用較難,需要進行變換,經過變換後WGAN的損失函數為(具體變換見論文原文):

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

    該函數的意思就是要求搜尋所有Lipschitz常數【注解5】小于K的函數f,并取f在後面那一坨的上确界,并除以K。

       由于函數f有很多中形式,是以選擇用神經網絡來拟合或者囊括盡可能多的f。

【注解5】Lipschitz常數

    連續函數f如果在其定義域内的導數f’的絕對值|f’|滿足以下條件:

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP
    就稱這個函數是Lipschitz連續的,并且稱K為Lipschitz常數。

(2)WGAN的算法流程:

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

(3)WGAN的改進之處:

    簡單來說WGAN相對于GAN的改變就是一下四點:

  • 判别器最後一層去掉Sigmoid;
  • 生成器和判别器的loss不取log;
  • 每次得到D的參數更新值之後,将其剪切(Chip)到一個較小的區間[-c,c],使其滿足Lipschitz條件;
  • 不要用基于動量的優化算法(包括momentum和Adam),推薦RMSProp,SGD也行

參考

[1] “令人拍案叫絕的Wasserstein GAN”,知乎

https://zhuanlan.zhihu.com/p/25071913;

6. WGAN-GP:WGAN-GP改善WGAN梯度剪切問題

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

本文次發表于2017.12.25

6.1 原始WGAN的缺陷

    本文的三作Martin Ajorvsky是WGAN論文中的一作,本文是對WGAN的梯度剪切問題的改進。

    原始WGAN為實作Lipschitz連續條件,将D的參數更新值剪切到較小的區間[-c,c],這使得參數在-c與c兩點處聚集,限制拟合能力,如圖:

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

    下圖是随着判别器層數增大,梯度範數的Log值的變化曲線,可見WGAN的三條曲線都出現了梯度消失或者梯度爆炸,WGAN-GP則比較平穩。

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

6.2 WGAN-GP的改進

6.2.1 WGAN-GP的思想

    直接剪切太過于武斷,那就換一種柔和的方式。WGAN-GP将直接剪切替換為一個懲罰項,通過懲罰項限制梯度的值。

6.2.2 WGAN-GP的模型

(1)WGAN-GP的損失函數:

    WGAN-GP在原WGAN的損失函數後面添加了懲罰項:

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

    懲罰項中的1本來是Lipschitz常數K,目的是使得D的梯度既滿足Lipschitz條件(導數梯度不超過K),同時也不會太小(太小則學習太慢)。論文中為了簡便,将K定義為1。

注意事項:

  • 随機采樣:不需要對所有樣本都執行懲罰項,隻需要在真假樣本最容易混淆的區域每次随機采樣部分樣本,對其執行懲罰,這樣可以減小計算難度。
  • 懲罰項因子:懲罰項因子λ需要調試,本論文中使用的都是λ=1。
  • Batch Normalization:本文的梯度懲罰是對每個樣本單獨施加的,如果引入Batch Normalization會使得同個batch中不同樣本出現互相依賴的情況,是以建議不使用Batch Normalization,或者使用不會産生樣本依賴的Layer Normalization等。

(2)WGAN-GP的算法流程:

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

參考

[1] “WGAN-GP與WGAN及GAN的比較”,CSDN

https://blog.csdn.net/qq_38826019/article/details/80786061;

[2] “WGAN最新進展:從weight clipping到gradient penalty”,煉數成金

http://www.dataguru.cn/article-11229-1.html;

本文完

關注本公衆号“AI微刊”,背景發送“GAN ZOO”,即可獲得GAN ZOO系列論文包以及源代碼資源包。

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

微信号:AI微刊

GAN ZOO 第2節: 對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GPGAN ZOO 第2節:對原始GAN的損失函數進行改進:LSGAN、WGAN、WGAN-GP

背景發送“GAN ZOO ”,即可獲得GAN ZOO系列論文包以及源代碼資源包。

繼續閱讀