天天看點

GAN生成的結果多樣性不足怎麼辦?那就再添一個鑒别器!

近期,澳洲迪肯大學圖像識别和資料分析中心發表了一篇新的論文,由Tu Dinh Nguyen, Trung Le, Hung Vu, Dinh Phung編寫,該論文就生成對抗網絡(GAN)的模式崩潰問題進行了讨論并給出了一種新的有效的解決方案 D2GAN,論文譯稿由雷鋒網 AI 科技評論編輯,原文連結請點選。

這篇文章介紹了一種解決生成對抗網絡(GAN)模式崩潰問題的方法。這種方法很直覺但是證明有效,特别是當對GAN預先設定一些限制時。在本質上,它結合了Kullback-Leibler(KL)和反向KL散度的差異,生成一個目标函數,進而利用這些分支的互補統計特性捕捉多模式下分散預估密度。這種方法稱為雙鑒别器生成對抗網絡(Dual discriminator generative adversarial nets, D2GAN),顧名思義,與GAN不同的是,D2GAN有兩個鑒别器。這兩個鑒别器仍然與一個生成器一起進行極大極小的博弈,一個鑒别器會給符合分布的資料樣本給與高獎勵,而另外一個鑒别器卻更喜歡生成器生成的資料。生成器就要嘗試同時欺騙兩個鑒别器。理論分析表明,假設使用最強的鑒别器,優化D2GAN的生成器可以讓原始資料庫和生成器産生的資料間的KL和反向KL散度最小化,進而有效地避免模式崩潰的問題。作者進行了大量的合成和真實資料庫的實驗(MNIST,CIFAR-10,STL-10,ImageNet),對比D2GAN和最新的GAN變種的方法,并進行定性定量評估。實驗結果有效地驗證了D2GAN的競争力和優越的性能,D2GAN生成樣本的品質和多樣性要比基準模型高得多,并可擴充到ImageNet資料庫。

生成式模型是研究領域的一大分支并且在最近幾年得到了飛速的成長,成功地部署到很多現代的應用中。一般的方法是通過解決密度預測問題,即學習模型分布Pmodel來預測置信度,在資料分布Pdata未知的情況下。這種方法的實作需要解決兩個基本問題。

首先,生成模型的學習表現基于訓練這些模型的目标函數的選擇。最為廣泛使用的目标,即事實标準目标,是遵循遵循最大似然估計原理,尋求模型參數以最大限度地提高訓練資料的似然性。這與最小化KL散度資料分布和模型分布上的差異的方法相似。這種最小化會導緻Pmodel覆寫Pdata的多種模式,但是可能會引起一些完全看不到的和潛在的不希望的樣本。相反地,另外一種方法通過交換參數,最小化:,一般稱其為反KL散度。觀察發現,對反KL散度準則優化模拟了模式搜尋的過程,Pmodel集中在Pdata的單一模式,而忽略了其他模式,稱這種問題為模式崩潰。

第二個問題是密度函數Pmodel公式的選擇問題。一種方法是定義一個明确的密度函數,然後直接的根據最大似然架構進行參數估計。另外一種方法是使用一個不明确的密度函數記性資料分布估計,不需要使用Pmodel的解析形式。還有一些想法是借用最小包圍球的原理來訓練生成器,訓練和生成的資料,在被映射到特征空間後,被封閉在同一個球體中。這種方法最為著名的先驅應用是生成對抗網絡(GAN),它是一種表達生成模型,具備生成自然場景的尖銳和真實圖像的能力。與大多數生成模型不同的是,GAN使用了一種激進的方法,模拟了遊戲中兩個玩家對抗的方法:一個生成器G通過從噪聲空間映射輸入空間來生成資料;鑒别器D則表現得像一個分類器,區分真實的樣本和生成器生成的僞圖像。生成器G和鑒别器D都是通過神經網絡參數化得來的,是以,這種方法可以歸類為深度生成模型或者生成神經模型。

GAN的優化實際上是一個極大極小問題,即給定一個最優的D,學習的目标變成尋找可以最小化Jensen-Shannon散度(JSD)的G:。JSD最小化的行為已經被實踐證明相較于KL散度更近似于反KL散度。這,另一方面,也導緻了之前提到的模式崩潰問題,在GAN的應用領域臭名昭著,即生成器隻能生成相似的圖檔,低熵分布,樣本種類匮乏。

近期的研究通過改進GAN的訓練方式來解決模式崩潰的問題。一個方法是使用mini-batch分辨法巧妙地讓鑒别器分辨與其他生成樣本非正常相似的圖檔。盡管這種啟發方式可以幫助快速生成具有視覺吸引力的樣本,但是它的計算代價很高,是以,通常應用于鑒别器的最後一個隐藏層。另外一個方法是把鑒别器的優化通過幾個步驟展開,在訓練中産生一個代理目标來進行生成器的更新。第三種方法是訓練多個生成器,發現不同的資料模式。同期的,還有一些其他的方法,運用autoencoders進行正則化或者輔助損失來補償丢失的模式等。這些方法都可以在一定程度上改善模式崩潰的問題,但是由此帶來了更高的計算複雜度,進而無法擴充到ImageNet這種大規模的和具有挑戰性的視覺資料庫上。

應對這些挑戰,作者們在這篇論文中提出了一種新的方法,既可以高效地避免模式崩潰問題又可以擴充到龐大的資料庫(比如:ImageNet等)。通過結合KL和反KL散度生成一個統一的目标函數,進而利用了兩種散度的互補統計特性,有效地在多模式下分散預估密度。使用GAN的架構,量化這種思路,便形成了一種新穎的生成對抗架構:鑒别器D1(通過鑒别資料來自于Pdata而不在生成分布PG中擷取高分),鑒别器D2(相反地,來自于PG而不在Pdata中)和生成器G(嘗試欺騙D1、D2兩個鑒别器)。作者将這種方法命名為雙鑒别器生成對抗網絡(D2GAN)。

實驗證明,訓練D2GAN與訓練GAN會遇到同樣的極大極小問題,通過交替更新生成器和鑒别器可以得到解決。理論分析表明,如果G、D1和D2具有足夠的容量,如非參數的限制下,在最佳點,對KL和反KL散度而言,訓練标準确實導緻了資料和模型分布之間的最小距離。這有助于模型在各種資料分布模式下進行公平的機率分布,使得生成器可一次完成資料分布恢複和生成多樣樣本。另外,作者還引入了超參數實作穩定地學習和各種散度影響的控制。

作者進行了大量的實驗,包括一個合成資料庫和具備不同特征的四個真實大規模資料庫(MNIST、CIFAR10、STL-10、ImageNet)。衆所周知,評估生成模型是非常困難的,作者花費了很多時間,使用了各種評估辦法,定量的對比D2GAN和最新的基線方法。實驗結果表明,D2GAN可以在保持生成樣本品質的同時提高樣本的多樣性。更重要的是,這種方法可以擴充到更大規模的資料庫(ImageNet),并保持具有競争力的多樣性結果和生成合理的高品質樣本圖檔。

簡而言之,這種方法具有三個重要的貢獻:(i)一種新穎的生成對抗模型,提高生成樣本的多樣性;(ii)理論分析證明這種方法的目标是優化KL和反KL散度的最小差異,并在PG=Pdata時,實作全局最優;(iii)使用大量的定量标準和大規模資料庫對這種方法進行綜合評估。

作者們的實作方法如下:

首先介紹一下生成對抗網絡(GAN),具有兩個玩家:鑒别器D和生成器G。鑒别器D(x),在資料空間中取一個點x,然後計算x在資料分布Pdata中而不是生成器G生成的機率。同時,生成器先向資料空間映射一個取自先導P(z)的噪聲向量z,擷取一個類似于訓練資料的樣本G(z),然後使用這個樣本來欺騙鑒别器。G(z)形成了一個在資料域的生成分布PG,和機率密度函數PG(x)。G和D都由神經網絡構成(見圖1a),并通過如下的極大極小優化得以學習:

學習遵循一個疊代的過程,其中鑒别器和生成器交替地更新。假設固定G,最大化D可以獲得最優鑒别器,同時,固定最優D*,最小化G可以實作最小化Jensen-Shannon(JS)散度(資料和模型分布:)。在博弈的納什均衡下,模型分布完全恢複了資料分布:PG=Pdata,進而鑒别器現在無法分辨真假資料:。

由于JS散度通過大量的實驗資料證明與反KL散度的特性相同,GAN也會有模式崩潰的問題,是以,其生成的資料樣本多樣性很低。

為了解決GAN的模式崩潰問題,下方介紹了一種架構,尋求近似分布來有效地涵蓋多模式下的多模态資料。這種方法也是基于GAN,但是有三個組成部分,包括兩個不同的鑒别器D1、D2和一個生成器G。假定一個資料空間中的樣本x,如果x是資料分布Pdata中的,D1(x)獲得高分,如果是模式分布PG中的,則獲得低分。相反地,如果x是模式分布PG中的,D2(x)獲得高分,如果是資料分布Pdata中的,D2(x)獲得低分。與GAN不同的是,得分的表現形式為R+而不是[0,1]中的機率。生成器G的角色與GAN中的相似,即從噪聲空間中映射資料與真實資料進行合成後欺騙D1和D2兩個鑒别器。這三個部分都由神經網絡參數化而成,其中D1和D2不分享它們的參數。這種方法被稱為雙鑒别器生成對抗網絡(D2GAN),見上圖1b。D1、D2和G遵循如下的極大極小公式:

其中超參數為了實作兩個目的。第一個是為了穩定化模型的學習過程。兩個鑒别器的輸出結果都是正的,D1(G(z))和D2(x)可能會變得很大并比LogD1(x)和LogD2(x)有指數性的影響,最終會導緻學習的不穩定。為了克服這個問題,降低α和β的值。第二個目的是控制KL和反KL散度對優化的影響。後面介紹過優化方法後再對這個部分進行讨論。

與GAN相似的是,通過交替更新D1、D2和G可以訓練D2GAN。

通過理論分析發現,假設G、D1和D2具備足夠的容量,如非參數的限制下,在最佳點,G可以通過最小化模型和資料分布的KL和反KL散度恢複資料分布。首先,假設生成器是固定的,通過(w.r.t)鑒别器進行優化分析:

證明:根據誘導測度定理,兩個期望相等:

當時,。目标函數可以推演如下:

考慮到裡面的函數積分,給定x,通過兩個變量D1、D2最大化函數,得到D1*(x)和D2*(x)。将D1和D2設定為0,可以得到:

 是非正數,則證明成立并得到了最大值。

接下來,,計算生成器G的最優方案G*。

證明:将D1*和D2*代入極大極小方程,得到:

 分别是KL和反KL散度。這些散度通常是非負的,并且隻在PG*=Pdata時等于0。換言之,生成器生成的分布PG*與資料分布完全等同,這就意味着由于兩個分布的傳回值都是1,兩個鑒别器在這種情況下就不能分辨真假樣本了。

如上公式中生成器的誤差表明提高α可以促進最小化KL散度()的優化,提高β可促進最小化反KL散度()的優化。通過調整α和β這兩個超參數,可以平衡KL散度和反KL散度的影響,進而有效地避免模式崩潰的問題。

在這個部分,作者進行了廣泛的實驗來驗證的提高模式覆寫率和提出的方法應用在大規模資料庫上的能力。使用一個合成的2D資料庫進行視覺和數值驗證,并使用四個真實的資料庫(具有多樣性和大規模)進行數值驗證。同時,将D2GAN和最新的GAN的應用進行對比。

從大量的實驗得出結論:(i)鑒别器的輸出具有softplus activations:,如正ReLU;(ii)Adam優化器,學習速率0.0002,一階動量0.5;(iii)64個樣本作為訓練生成器和鑒别器的minibatch訓練單元;(iv)0.2斜率的Leaky ReLU;(v)權重從各項同性的高斯(Gaussian)分布:進行初始化,0偏差。實作的過程使用了TensorFlow,并且在文章發表後釋出出來。下文将介紹實驗過程,首先是合成資料庫,然後是4個真實資料庫。

在第一個實驗中,使用已經設計好的實驗方案對D2GAN處理多模态資料的能力進行評估。特别的是,從2D混合8個高斯分布和協方差矩陣0.02I 擷取訓練資料,同時中位數分布在半徑2.0零質心的圓中。使用一個簡單的架構,包含一個生成器(兩個全連接配接隐藏層)和兩個鑒别器(一個ReLU激發層)。這個設定是相同的,是以保證了公平的對比。圖2c顯示了512個由D2GAN和基線生成的樣本。可以看出,正常的GAN産生的資料在資料分布的有效模式附近的一個單一模式上奔潰了。而unrolledGAN和D2GAN可以在8個混合部分分布資料,這就印證了能夠成功地學習多模态資料的能力。最後,D2GAN所截取的資料比unrolledGAN更精确,在各種模式下,unrolledGAN隻能集中在模式質心附近的幾個點,而D2GAN産生的樣本全分布在所有模式附近,這就意味着D2GAN産生的樣本比unrolledGAN多得多。

下一步,定量的進行生成資料品質的對比。因為已知真實的分布Pdata,隻需進行兩步測量,即對稱KL散度和Wasserstein距離。這些測量分别是對由D2GAN、unrolledGAN和GAN的10000個點歸一化直方與真實的Pdata之間的距離計算。圖2a/b再次清楚了表明了D2GAN相對于unrolled和GAN的優勢(距離越小越好);特别是Wasserstein度量,D2GAN離真實分布的距離基本上減小到0了。這些圖檔也表達了D2GAN相對于GAN(綠色曲線)和unrolledGAN(藍色曲線)在訓練時的穩定性。

下面,使用真實資料庫對D2GAN進行評估。在真實資料庫條件下,資料擁有更高的多樣性和更大的規模。對含有卷積層的網絡,根據DCGAN進行設計分析。鑒别器使用帶步長的卷積,生成器使用分步帶步長的卷積。每個層都進行批處理标準化,除了生成器輸出層和鑒别器的輸入層。鑒别器還使用Leaky ReLU 激發層,生成器使用ReLU層,除非其輸出是tanh,原因是各像素的強度值在回報到D2GAN模型前已經變換到[-1,1]的區間内。唯一的差別是,在D2GAN下,當從N(0,0.01)初始化權重時,其表現比從N(0,0.02)初始化權重的效果好。架構的細節請看論文附件。

評估生成對抗模型産生的樣本是很難的,原因有生成機率判斷标準繁多、缺乏有意義的圖像相似性度量标準。盡管生成器可以産生看似真實的圖像,但是如果這些圖像看起來非常近似,樣本依然不可使用。是以,為了量化各種模式下的圖像品質,同時生産高品質的樣本圖樣,使用各種不用的ad-hoc度量進行不同的實驗來進行D2GAN方法與各基線方法的效果對比。

首先,使用起始分值(Inception Score),計算通過:,這裡P(y|x)是通過預訓練的初始模型的圖像x的條件标簽分布,P(y)是邊際分布:。這種度量方式會給品質高的多樣的圖檔給高分,但是有時候很容易被崩潰的模式欺騙,導緻産生非常低品質的圖檔。是以,這種方式不能用于測量模型是否陷入了錯誤的模式。為了解決這個問題,對有标簽的資料庫,使用MODE score:

這裡,是訓練資料的預估标簽的經驗分布。MODE score的值可以充分的反應生成圖像的多樣性和視覺品質。

這個部分使用手寫數字圖像-MNIST,資料庫包含有60,000張訓練圖像和10,000張測試灰階圖(28*28像素),數值區間從0到9。首先,假設MNIST有10個模式,代表了資料分支的連接配接部分,分為10個數字等級。然後使用不同的超參數配置進行擴充的網格搜尋,使用兩個正則常數α和β,數值為{0.01,0.05,0.1,0.2}。為了進行公平的對比,對不同的架構使用相同的參數和全連接配接層。

評估部分,首先訓練一個簡單的但有效的3-layer卷積網絡(MNIST測試庫實作0.65%的誤差),然後将它應用于預估标簽的機率和生成樣本的MODE score計算中。圖3左顯示了3個模式下MODE score的分布。清晰的看到,D2GAN相對于标準GAN和Reg-GAN的巨大優越性,其分數的最大值基本落在區間【8.0-9.0】。值得注意的是,盡管提高網絡的複雜度,MODE score基本保持高水準。這幅圖檔中隻表現了最小網絡和最少層和隐藏單元的結果。

為了研究α和β的影響,在不同的α和β的數值下進行試驗(圖3右)。結果表明,給定α值,D2GAN可以在β達到一定數值時獲得更好的MODE score,當β數值繼續增大,MODE score降低。

MNIST-1K.   假定10個模式的标準MNIST資料庫相當簡單。是以,基于這個資料庫,作者使用一個更具挑戰性的資料庫進行測試。沿用上述的方式,假定一個新的有1000個等級的MNIST資料庫(MNIST-1K),方法為用3個随機數字組成一個RGB圖像。由此,可以組成1000個離散的模式,從000到999。

在這個實驗中,使用一個更強大的模型,鑒别器使用卷積層,生成器使用轉置卷積。通過測試模式的數量進行模型的性能評估,其中模型在25,600個樣本中至少産生一個模式,同時反KL散度分布介于模型分布(如從預訓練的MNIST分類器預測的标簽分布)和期望的資料分布之間。表1報告了D2GAN與GAN、unrolledGAN、GCGAN和Reg-GAN之間的對比。通過對比可以看出D2GAN具有極大的優勢,同時模型分布和資料分布之間的差距幾近為0。

 下面是将D2GAN應用到更廣泛的自然場景圖像上,用于驗證其在大規模資料庫上的表現。使用三個經常被使用的資料庫:CIFAR-10,STL-10和ImageNet。CIFAR-10包含50,000張32*32的訓練圖檔,有10個等級:飛機,機車,鳥,貓,鹿,狗,青蛙,馬,船和卡車(airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck)。STL-10,是ImageNet的子資料集,包含10,000張未被标記的96*96的圖檔,相對于CIFAR-10更多樣,但是少于ImageNet。将所有的圖檔向下縮小3倍至32*32分辨率後,再對網絡進行訓練。ImageNet非常龐大,擁有120百萬自然圖檔,包含1000個類别,通常ImageNet是深度網絡領域訓練使用的最為龐大和複雜的資料庫。使用這三個資料庫進行蓄念和計算,Inception score的結果如下圖和下方表格所示:

表格中和圖4中表示了Inception score在不同資料庫和不同模型上的不同值。值得注意的是,這邊的對比是在一個完美無監督的方法下,并且沒有标簽的資訊。在CIFAR-10資料庫上使用的8個基線模型,而在STL-10和ImageNet資料庫上使用了DCGAN、DFM(denoising feature matching)作對比。在D2GAN的實作上使用了與DCGAN完全一緻的網絡架構,以做公平的對比。在這三個實驗結果中,可以看出,D2GAN的表現低于DFM,但是在很大的程度高于其他任何一個基線模型。這種遜于DFM的結果印證了對進階别的特征進行自動解碼是提高多樣性的一種有效方法。D2GAN可與這種方式相容,是以融合這種做法可以為未來的研究帶來更好的效果。

最後,在圖5中展現了使用D2GAN生成的樣本圖檔。這些圖檔都是随機産生的,而不是特别挑選的。從圖檔中可以看出,D2GAN生成了可以視覺分辨的車,卡車,船和馬(在CIFAR-10資料庫的基礎上)。在STL-10的基礎上,圖檔看起來相對比較難以辨認,但是飛機,車,卡車和動物的輪廓還是可以識别的;同時圖檔還具備了多種背景,如天空,水下,山和森林(在ImageNet的基礎上)。這印證了使用D2GAN可以生成多樣性的圖檔的結論。

總結全文,作者介紹了一種全新的方法,結合KL(Kullback-Leibler)和反KL散度生成一個統一的目标函數來解決密度預測問題。這種方法利用了這兩種散度的互補統計特性來提高生成器産生的圖像的品質和樣本的多樣性。基于這個原理,作者引入了一種新的網絡,基于生成對抗網絡(GAN),由三方構成:兩個鑒别器和一個生成器,并命其為雙鑒别器生成對抗網絡(dual discriminator GAN, D2GAN)。如果設定兩個鑒别器是固定的,同時優化KL和反KL散度進行生成器的學習,通過這種方法可以幫助解決模式崩潰的問題(GAN的一大亟待突破的難點)。

作者通過大量的實驗對其提出的方法進行了驗證。這些實驗的結果證明了D2GAN的高效性和擴充性。實驗使用的資料庫包括合成資料庫和大規模真實圖檔資料庫,即MNIST、CIFAR-10,STL-10和ImageNet。相較于最新的基線方法,D2GAN更具擴充性,可以應用于業内最為龐大和複雜的資料庫ImageNet,盡管取得了比融合DFM(denoising feature matching)和GAN的方法低的Inception score,但遠遠高于其他GAN應用的實驗結果。最後,作者指出,未來的研究可以借鑒融合DFM和GAN的做法,在現有的方法基礎上增加類似半監督學習、條件架構和自動編碼等的技術,更進一步的解決生成對抗網絡在應用中的問題。

<b></b>

<b>本文作者:雪莉•休斯敦</b>