天天看點

【GAN優化】GAN訓練的小技巧

頭一陣子放假了,專欄都沒有怎麼更新了,今天開始繼續更新(想問問小夥伴們都放了多久的假期?我們隻有兩周感覺時間好短呀~)

作者&編輯 | 小米粥

上一期中,我們說明了GAN訓練中的幾個問題,例如由于把判别器訓練得太好而引起的梯度消失的問題、通過采樣估算距離而造成偏差的問題、minmax問題不清晰以及模式崩潰、優化選擇在參數空間而非函數空間的問題等,今天這篇小文将從博弈論的角度出發來審視一下GAN訓練時的問題,說明訓練GAN其實是在尋找納什均衡,然後說明達到納什均衡或者說損失函數收斂是很難的,并最後給出了3個穩定訓練的小技巧。

1 博弈論與GAN

大家對GAN的基本模型想必已經非常熟悉了,我們先從博弈論的角度來重新描述GAN模型。遊戲中有兩個玩家:D(判别器)和G(生成器),D試圖在判别器的參數空間上尋找最好的解使得它的損失函數最小:

【GAN優化】GAN訓練的小技巧

G也試圖在生成器的參數空間上尋找最好的解使得它的損失函數最小:

【GAN優化】GAN訓練的小技巧

需要說明,D和G并不是彼此獨立的,對于GAN,整個博弈是“交替進行決策”的。例如先确定生成器G的參數,則D會在給定的G的參數的條件下更新判别器的參數以此最小化D的損失函數,如下面中藍線過程(提升D的辨識能力);接着G會在給定的D的參數的條件下更新判别器的參數以此來最小化G的損失函數,如下面中綠線過程(提升G的生成能力)......直到達到一個穩定的狀态:納什均衡。

【GAN優化】GAN訓練的小技巧

在納什均衡點,兩者的參數到達一種“制衡”狀态。在給定G的參數情況下,D目前的參數便對應了D損失函數的最小值,同樣在給定D的參數情況下,G目前的參數便對應了G損失函數的最小值,也就是說在交替更新過程中,D和G均不可能單獨做出任何改變。

解空間中可能存在多個納什均衡點,而且納什均衡點并不意味着全局最優解,但是是一種經過多次博弈後的穩定狀态,是以說GAN的任務是并非尋找全局最優解,而是尋找一個納什均衡狀态,損失函數收斂即可。在損失函數非凸、參數連續、參數空間次元很高的情況下,不可能通過嚴格的數學計算去更新參數進而找到納什均衡,在GAN中,每次參數更新(對應藍線、綠線表示的過程)使用的是梯度下降法;另外,每次D或者G對自身參數更新都會減少自身的損失函數同時加大對方的損失函數,這導緻了尋找GAN的納什均衡是比較困難的。

這裡有一個比GAN簡單多的例子表明很多時候納什均衡的狀态難以達到:

【GAN優化】GAN訓練的小技巧

使用梯度下降法發現x,y在參數空間中并不會收斂到納什均衡點(0,0),損失函數的表現為:不收斂。

【GAN優化】GAN訓練的小技巧

針對GAN訓練的收斂性問題,我們接下來将介紹幾種啟發式的訓練技巧。

2 特征比對

在GAN中,判别器D輸出一個0到1之間的标量表示接受的樣本來源于真實資料集的機率,而生成器的訓練目标就是努力使得該标量值最大。如果從特征比對(feature matching)的角度來看,整個判别器D(x)由兩部分功能組成,先通過前半部分f(x)提取到樣本的抽象特征,後半部分的神經網絡根據抽象特征進行判定分類,即

【GAN優化】GAN訓練的小技巧

f(x)表示判别器中截止到中間某層神經元激活函數的輸出。在訓練判别器時,我們試圖找到一種能夠區分兩類樣本的特征提取方式f(x),而在訓練生成器的時候,我們可以不再關注D(x)的機率輸出,我們可以關注:從生成器生成樣本中用f(x)提取的抽象特征是否與在真實樣本中用f(x)提取的抽象特征相比對,另外,為了比對這兩個抽象特征的分布,考慮其一階統計特征:均值,即可将生成器的目标函數改寫為:

【GAN優化】GAN訓練的小技巧

采用這樣的方式,我們可以讓生成器不過度訓練,讓訓練過程相對穩定一些。

3 曆史均值

曆史均值(historical averaging)是一個非常簡單方法,就是在生成器或者判别器的損失函數中添加一項:

【GAN優化】GAN訓練的小技巧

這樣做使得判别器或者生成器的參數不會突然産生較大的波動,直覺上看,在快要達到納什均衡點時,參數會在納什均衡點附近不斷調整而不容易跑出去。這個技巧在處理低維問題時确實有助于進入納什均衡狀态進而使損失函數收斂,但是GAN中面臨的是高維問題,助力可能有限。

4 單側标簽平滑

标簽平滑(label smoothing)方法最開始在1980s就提出過,它在分類問題上具有非常廣泛的應用,主要是為了解決過拟合問題。一般的,我們的分類器最後一層使用softmax層輸出分類機率(Sigmoid隻是softmax的特殊情況),我們用二分類softmax函數來說明一下标簽平滑的效果。

對于給定的樣本x,其類别為1,則标簽為[1,0],如果不用标簽平滑,隻使用“硬”标簽,其交叉熵損失函數為:

【GAN優化】GAN訓練的小技巧

這時候通過最小化交叉熵損失函數來訓練分類器,本質上是使得:

【GAN優化】GAN訓練的小技巧

其實也就是使得:

【GAN優化】GAN訓練的小技巧

對于給定的樣本x,使z1的值無限大(當然這在實際中是不可能的)而使z2趨于0,無休止拟合該标簽1,便産生了過拟合、降低了分類器的泛化能力。如果使用标簽平滑手段,對給定的樣本x,其類别為1,例如平滑标簽為[1-ε ,ε],交叉損失函數為:

【GAN優化】GAN訓練的小技巧

當損失函數達到最小值時,有:

【GAN優化】GAN訓練的小技巧

選擇合适的參數,理論上的最優解z1與z2存在固定的常數內插補點(此內插補點由ε決定),便不會出現z1無限大,遠大于z2的情況了。如果将此技巧用在GAN的判别器中,即對生成器生成的樣本輸出機率值0變為β ,則生成器生成的單樣本交叉熵損失函數為:

【GAN優化】GAN訓練的小技巧

而對資料集中的樣本打标簽由1降為α,則資料集中的單樣本交叉熵損失函數為:

【GAN優化】GAN訓練的小技巧

總交叉損失函數為:

【GAN優化】GAN訓練的小技巧

求導容易得其最優解D(x)為:

【GAN優化】GAN訓練的小技巧

實際訓練中,有大量這樣的x:其在訓練資料集中機率分布為0,而在生成器生成的機率分布不為0,他們經過判别器後輸出為β。為了能迅速“識破”該樣本,最好将β降為0,這就是所謂的單側标簽平滑。

訓練GAN時,我們對它的要求并不是找到全局最優解,能進入一個納什均衡狀态、損失函數收斂就可以了。(雖然這個納什均衡狀态可能非常糟糕)最近的幾篇文章将着重于讨論GAN訓練的收斂問題。

[1] Müller, Rafael, S. Kornblith , and G. Hinton . "When Does Label Smoothing Help?." 2019

[2] Salimans T , Goodfellow I , Zaremba W , et al. Improved Techniques for Training GANs[J]. 2016.

總結

這篇文章闡述了GAN的訓練其實是一個尋找納什均衡狀态的過程,然而想采用梯度下降達到收斂是比較難的,最後給出了幾條啟發式的方法幫助訓練收斂。

下期預告:GAN訓練中的動力學

【GAN優化】GAN訓練的小技巧