天天看點

WGAN學習筆記

WGAN學習筆記

剛剛接觸深度學習三個月的小白,從現在開始記錄自己的學習過程,希望各位能提出一些寶貴的意見。。

Wasserstein GAN

一、Introducion

無監督學習

如何學習一個可能的分布?學習一個density,通過定義一個參數densities,找到可以最大化似然我們資料的Pθ(θ是R中的d維向量)

WGAN學習筆記

如果真實資料是分布Pr ,Pθ是參數密度分布,目的就是要最小化KL(Pr||Pθ)

但KL distance 很容易逼近無窮大。

通常的解決辦法是add 一個噪聲限制model的分布,這就是為什麼實際上所有的生成模型都include a noise component.最簡單的情況是假設一個高帶寬的高斯噪聲去覆寫所有的example,但是噪聲的存在降低了sample的品質(make them blurry-模糊)比如recent paper23中…噪聲的數量很大,(這一塊我沒太明白,也沒有深究)

定義Z服從分布p(z),并定義映射gθ:Z→X,生成sample服從分布Pθ,通過改變θ可以改變分布,并且使它接近于真實資料的分布Pr。VAEs和GAN都是用這種方法。GAN在目标函數的定義上更加具有靈活性,including 所有F散度和exotic combinations

#f散度:用來衡量兩個機率密度p和q的差別,也就是衡量這兩個分布的相似程度

WGAN學習筆記

我們的重點是去衡量生成模型的分布與真實資料的分布之間的差距,或者說如何定義一個距離去衡量。為了優化參數ϴ,希望定義模型的分布Pθ,使得映射Pθ連續,也就是當ϴt趨近于ϴ,Pϴt趨近于Pθ。然而,分布Pθt是否收斂取決于我們計算這兩個分布之間的距離的方式,distance越weaker,越容易找到一個ϴ空間到Pθ空間的映射,因為這樣有利于分布的收斂。如果我們定義ρ是兩個分布之間的distance,我們應該找到一個損失函數。

這篇文章的貢獻主要有:

在section2,我們提供了一種理論分析,對于EM距離的表現和幾種流行的可能的距離和散度對比

在section3,定義了Wasserstein GAN最小化EM距離

在section4,我們展示了WGAN解決了GAN在訓練中的問題。在訓練WGAN的時候不需要在生成器和判别器之間保持一種極度小心的平衡,也不需要一種十分謹慎的網絡結構設計,在GAN中的的moda dropping現象也得到改善,WGAN最優點在于,在訓練判别器至最優的過程中能夠連續評估EM距離。

二、Different Distance

Χ是一個緊集,∑代表χ的所有波伊爾子集,Prob(χ)代表the space of probability measures defined onχ,定義:

WGAN學習筆記
WGAN學習筆記

(sup上确界,inf下确界)

γ(x,y)意思是要使分布Pr轉換成Pg,x對應于y的變化。EM距離就是最優的轉換計劃的損失。

Example1

例1證明了was距離(即EM距離)相比于JS,KL,TV距離的優越性。并且W(Pθ,P0)是一個在θ上連續的損失函數

Theorem1

講了最小化EM距離在神經網絡中是可行的,比JS散度更好,這幾種距離or散度由強到弱依次是:KL,JS,TV,EM是最弱的

Theorem2

講了在高次元上,JS,KL,TV,Was距離都趨近于0,即都可以做cost,但是在low dimensional manifolds隻有Was距離合适

三、Wasserstein GAN

推論2證明了W(Pr,Pθ)比JS更好優化,但是下确界難以控制。 KRduality告訴我們,

WGAN學習筆記

在K維Lipschitz連續上,對某些K,我們要解決大的問題就是:

WGAN學習筆記

如果上确界對某些w∈W可以得到(一種非常強的相似假設當證明評估的一緻性的時候),這個過程會産生一個W(Pr,Pθ)的calulation up to a 增加的constant。

推論3

讓Pr是任何可能的分布,Pθ是gθ(Z)的分布,Z是随機變量,密度是p,gθ是一個符合假設1的函數,

(證明在附錄C)

現在要找到函數f并解決最大值問題

WGAN學習筆記

我們要訓練一個參數化的神經網絡with權重w在緊空間W中,通過反向傳播,就像在GAN中一樣。W is compact implies that 所有函數fw對于隻依賴于W而不是單獨權重的K都K-Lipschitz, 為了有存在于緊空間中的參數w,我們可以做的就是在每次梯度更新後限制權重到一個固定範圍内,(比如W=[-0.01,0.01]l)

Weight clipping是一個強迫Lipschitz限制的糟糕的方法。如果clipping參數太大,就會花很長時間去訓練,若clipping太小,當層數較深時會梯度消失,or BN層沒用到

WGAN學習筆記

EM距離是連續可微的,訓練判别器越多,我們可以得到更可靠的梯度 of Was距離(WAS幾乎處處可微。JS散度會出現梯度消失。GAN的生成器學習去鑒别真實圖檔和生成圖檔非常迅速,and provide no reliable gradients information,WGAN的critic,不會飽和,且處處有明确的梯度。事實就是我們限制函數限制函數的增長 使其在不同的空間中維持線性,強迫最優的critic有這個行為。(?)

或許更重要的是,在訓練critic至最優的過程中不會出現模态坍塌(moda collapse), This is due to the fact that mode collapse comes from the fact that the optimal generator for a fixed discriminator is a sum of deltas on the points the discriminator assigns the highest values, as observed by [4] and highlighted in [11].

WGAN梯度處處明晰

四、實證結果

兩點主要benefits:

a.更好的損失度量方法,與生成器的收斂性和樣本品質相關聯

b.提高了優化過程的穩定性

4.1 實驗程序

目标分布學習的是LSUN-BEDROOMS,基線對照是DCGAN,生成樣本是3通道 64X64,用了A1中的超參數

4.2有意義的損失度量

因為WGan示範算法嘗試訓練critic f well before 每一個生成器更新,損失函數是估計的EM距離。

第一個實驗說明了估計距離和生成樣本品質的相關性,除了DCGAN,還跑了修改了生成器or生成器判别器都修改的(by 4-layer ReLU-mlp with 512個隐藏單元。訓練曲線和樣本在不同時期的訓練表現,在lower errors和高樣本品質之間有明顯的相關性。

WGAN學習筆記

上左:生成器 MLP ,4個隐藏層,每層512單元。随着訓練過程,loss連續不斷的降低,伴随着樣本品質的提升。上右,生成器是DCGAN,loss降低得非常快,圖檔品質也有提升,這兩種的critic都是沒有sigmoid的DCGAN是以loss可以做比較。下面的圖:生成器和判别器都是MLPs ,高學習率(是以訓練失敗),損失和樣本都是constant,曲線通過中值濾波( median filter)

中值濾波:中值濾波法是一種非線性平滑技術,它将每一像素點的灰階值設定為該點某鄰域視窗内的所有像素點灰階值的中值。

圖三展示了通過三種結構訓練的WGAN通過EM距離的評估

據我們所知,這是第一次GAN作品的這樣一種特性的展示,即GAN的損失函數展示了其收斂性。這種特性是極其有用的,This property is extremely useful when doing research in adversarial networks as one does not need to stare at the generated samples to figure out failure modes and to gain information on which models are doing better over others。

然而我們并不是聲稱這是評估生成模型的新方法,The constant scaling factor that depends on the critic’s architecture means it’s hard to compare models with different critics連續不斷的縮放因子依賴于critic的結構,意味着用不同的critic很難比較模型。并且critic沒有确定的容量,這使得我們很難根據EM距離的遠近得知我們的真實距離是多少。也就是說,我們成功的運用這個loss量度去證明我們實驗在不失敗前提下的可重複性,這是在訓練GANs時一個巨大的提升,以前的GANs沒有這種能力。

圖四是訓練GAN過程中的JS散度,在訓練GANs期間,生成器就是去最大化which is is a lower bound of 2JS(Pr, Pθ)−2 log 2. In the figure, we plot the quantity 1/2 L(D, gθ) + log 2, which is a lower bound of the JS distance

WGAN學習筆記

圖:JS 上左:一個MLP生成器 上右:一個DCGAN生成器,都是用标準GAN過程訓練,都有一個DCGAN判别器,兩張圖的錯誤率都有提升,右圖樣本變得更好但是JS散度升高或保持不變,樣本品質和loss之間沒有直接的相關性。下圖:生成器和判别器都是MLP,曲線不停的上升和下降,和圖檔品質無關。所有的曲線都經過了相同的中值濾波。

這些圖檔表明,随着圖檔品質的變化,。JS散度常常保持不變或者上升而不是下降。實際上JS散度保持在接近于log2約等于0.69,也是JSdistance的最大值。也就是說JSdistance飽和了,判别器的loss是0,生成器有時meaningful,有時出現模态坍塌。

當用高學習率,或者 uses a momentum based optimizer such as Adam [8] (with β1 > 0)

on the critic,WGAN訓練會變得不穩定,是以我們用RMSProp優化。

4.3提升穩定性

WGAN的一個優點是允許我們訓練critic直到穩定

後面沒怎麼看。。。

繼續閱讀