天天看點

從GAN到WGAN到WDGRL誤差函數的深入淺出解讀

這篇部落格分為三部分,先介紹GAN的loss函數,以及它存在的問題;接下來第二節介紹WGAN的loss函數由來,以及實作細節;最後介紹在程式中使用最多的WGAN-GP的loss函數。

  1. 傳統GAN的訓練困難原因

傳統GAN的Loss,該Loss有些不足的地方,導緻了GAN的訓練十分困難,表現為:

1、模式坍塌,即生成樣本的多樣性不足;

2、不穩定,收斂不了。

原因總結如下:

用KL Divergence和JS Divergence作為兩個機率的差異的衡量,最關鍵的問題是若兩個機率的支撐集不重疊,就無法讓那個參數化的、可移動的機率分布慢慢地移動過來,以拟合目标分布。(即KL散度和JS散度在兩個機率分布沒有重疊的情況下,無法反應兩者之間的差異性,是以無法進行學習優化。)

GAN的誤差函數:

判别器D的loss函數:

判别器判斷真實樣本的得分D(x)越高越好,判斷生成樣本的得分D(G(Z))越低越好

生成器G的loss函數:

生成器的目标是生成的樣本在判别器得分D (G(Z))越高越好。

  1. WGAN

Wasserstein Distance:(兩個機率分布的距離衡量名額)

定義如下:

第一句話的解釋很漂亮:

W(Pr,Pg)是這兩個機率分布的距離,它是兩個在同一空間上(次元相同)的随機變量x,y之差的範數均值的下确界。

下确界:某個集合X 的子集 E 的下确界(英語:infimum 或infima,記為inf E )是小于或等于的E 所有其他元素的最大元素,其不一定在E 內。

轉化為

f(x)是函數集 中的一個函數。 表示滿足1-Lipschitz條件的函數集。(Lipschitz條件是一個比通常連續更強的光滑性條件。直覺上,Lipschitz連續函數限制了函數改變的速度,符合利Lipschitz條件的函數的斜率,必小于一個稱為Lipschitz常數的實數)。

用K-Lipschitz條件代替:

Sup指上确界,inf指下确界

式要求得到上确界,上确界的具體函數形式我們不知道,但我們可以用神經網絡來逼近它,這是判别器(Discriminator)的作用,也就是Discriminator網絡充當了f(x)的角色,是以(4)等價于:

其中, 是樣本函數平均值

判别器D,目标是這個距離越大越好,

是以判别器的損失函數:

生成器隻能調節生成器參數,不能調節判别器參數,是以

這個距離越小越好.

參考:https://blog.csdn.net/StreamRock/article/details/81138621

其中的pytorch源碼,清楚地解釋了如果在程式中得到判别器和生成器的loss,其中WGAN對權重進行了修剪:

# Clip weights of discriminator

for p in discriminator.parameters():

p.data.clamp_(-opt.clip_value, opt.clip_value)

要保證fθ(x)滿足K-Lipschitz條件,夾逼了判别器的參數。

關于WGAN的loss函數,我發現這個總結更為精辟:

WGAN中,判别器D和生成器G的loss函數分别是:

  1. WGAN-GP

參考: https://blog.csdn.net/omnispace/article/details/77790497(解釋很精彩)

大部分程式中采用WGAN-GP(Gradient penalty)。

在引入梯度懲罰項之前,先介紹采用參數夾逼的方式存在的兩個問題:

  1. 判别器loss希望盡可能拉大真假樣本的分數差,然而weight clipping獨立地限制每一個網絡參數的取值範圍,在這種情況下我們可以想象,最優的政策就是盡可能讓所有參數走極端,要麼取最大值(如0.01)要麼取最小值(如-0.01)。

這樣帶來的結果就是,判别器會非常傾向于學習一個簡單的映射函數(想想看,幾乎所有參數都是正負0.01,都已經可以直接視為一個二值神經網絡了,太簡單了)。而作為一個深層神經網絡來說,這實在是對自身強大拟合能力的巨大浪費!判别器沒能充分利用自身的模型能力,經過它回傳給生成器的梯度也會跟着變差。

  1. 第二個問題,weight clipping會導緻很容易一不小心就梯度消失或者梯度爆炸。原因是判别器是一個多層網絡,如果我們把clipping threshold設得稍微小了一點,每經過一層網絡,梯度就變小一點點,多層之後就會指數衰減;反之,如果設得稍微大了一點,每經過一層網絡,梯度變大一點點,多層之後就會指數爆炸。隻有設得不大不小,才能讓生成器獲得恰到好處的回傳梯度,然而在實際應用中這個平衡區域可能很狹窄,就會給調參工作帶來麻煩。相比之下,gradient penalty就可以讓梯度在後向傳播的過程中保持平穩。

既然判别器希望盡可能拉大真假樣本的分數差距,那自然是希望梯度越大越好,變化幅度越大越好,是以判别器在充分訓練之後,其梯度norm其實就會是在K附近。知道了這一點,我們可以把上面的loss改成要求梯度norm離K越近越好,效果是類似的:

簡單地把K定為1,再跟WGAN原來的判别器loss權重合并,就得到新的判别器loss:

三個loss項均是期望的形式,在實際中通過采樣的方式獲得。前面兩個期望的采樣我們都熟悉,第一個期望是從真樣本集裡面采,第二個期望是從生成器的噪聲輸入分布采樣後,再由生成器映射到樣本空間。可是第三個分布要求我們在整個樣本空間 上采樣,這完全不科學!由于所謂的次元災難問題,如果要通過采樣的方式在圖檔或自然語言這樣的高維樣本空間中估計期望值,所需樣本量是指數級的,實際上沒法做到。

我們其實沒必要在整個樣本空間上施加Lipschitz限制,隻要重點抓住生成樣本集中區域、真實樣本集中區域以及夾在它們中間的區域就行了。具體來說,我們先随機采一對真假樣本,還有一個0-1的随機數:

在 和 的連線上随機插值采樣,

·  weight clipping是對樣本空間全局生效,但因為是間接限制判别器的梯度norm,會導緻一不小心就梯度消失或者梯度爆炸;

·  gradient penalty隻對真假樣本集中區域、及其中間的過渡地帶生效,但因為是直接把判别器的梯度norm限制在1附近,是以梯度可控性非常強,容易調整到合适的尺度大小。

這個采用點的擷取可以用下圖表示:

  1. 從真實資料 PdataPdata 中采樣得到一個點
  2. 從生成器生成的資料 PGPG 中采樣得到一個點
  3. 為這兩個點連線
  4. 線上上随機采樣得到一個點作為 Ppenalty的點。

注意:由于我們是對每個樣本獨立地施加梯度懲罰,是以判别器的模型架構中不能使用Batch Normalization,因為它會引入同個batch中不同樣本的互相依賴關系。如果需要的話,可以選擇其他normalization方法,如Layer Normalization、Weight Normalization和Instance Normalization,這些方法就不會引入樣本之間的依賴。論文推薦的是Layer Normalization。

繼續閱讀