天天看點

GAN(對抗生成網絡)的數學原理及基本算法

GAN在生成任務上與其他方法對比

Machine Learning (ML) 本質上是尋找一個函數 f : X → Y f:X\to Y f:X→Y,通過網絡來近似這個函數。Structured Learning (SL) 輸出相對于ML更加複雜,可能是圖、樹、序列……通常ML的問題,每個類别都會有一些樣本,但是SL則不會——輸出可能是輸入從來沒見過的東西。

在GAN 之前,auto-encoder (AE) 非常常用。AE結構:輸入 → \to → encoder → \to → vector c → \to → decoder → \to →輸出。訓練的時候要使得輸入輸出盡可能相近。當做生成任務的時候,截取AE的decoder部分,随機給vector c,輸出即生成的結果。是以可見AE可以用于做生成——即将decoder輸出視為生成資訊。但是這種訓練方式面對一個問題,假設A、B都是訓練集資訊,針對A、B網絡能夠很好的進行生成,但是當面對 0.5 A + 0.5 B 0.5A+0.5B 0.5A+0.5B網絡将會不知道輸出應該是什麼(最大的可能是兩個圖像的堆疊)。

對AE的改進叫做variational-AE (VAE) 在之前模型結構的基礎上,對輸入加上了噪聲,其餘不變。這種操作能夠讓模型能加穩定。

VAE同樣有一個問題,下圖是四種生成的情況。由于通常會以衡量輸入輸出的相似度作為評估标準,以L1或L2為例來講,下圖第一行的生成loss将會小于第二行。但從人的角度出發,第一行添加或者失去的一個像素點使得圖像更加不真實,相反,第二行盡管圖像與原圖差别更大,但是更加真實。

GAN(對抗生成網絡)的數學原理及基本算法

VAE的另一個問題是,如果資料分布較為分散,從降低訓練loss的情況出發,更傾向于産生資料分布介于分散分布之間。但往往,這樣會使得生成結果非常不真實。

GAN能夠彌補上述模型的缺陷。SL做生成任務,通常兩種思考方式:bottom-up和top-down。前者是通過一個一個小元件完成生成任務,但往往會失去大局觀。top-down是從大體上生成,但很難在細節上完成生成。GAN的通常構成包括generator(生成器)和discriminator(辨識器)。從SL的角度分析GAN,則可以将GAN看做同時具有bottom-up和top-down結構的生成模型。generator從細節上考慮如何生成,可視為bottom-up,discriminator從宏觀考慮生成效果是否真實,可視為top-down。

GAN背後的數學原理

首先,GAN的結構通常包括generator和discriminator,前者用于生成,後者對生成結果進行評價。下面介紹GAN背後的數學原理。

假設随機變量 X ∼ P d a t a ( x ) X\thicksim P_{data}(x) X∼Pdata​(x),通過GAN的generator生成的結果有: X ∼ P G ( x ) X\thicksim P_{G}(x) X∼PG​(x)。那麼如果生成器效果越好,那麼兩個分部将會越相近。通常使用最大似然進行優化。從 P d a t a ( x ) P_{data}(x) Pdata​(x)中進行采樣,擷取樣本 { x 1 , x 2 , x 3 … x m } \{ x_1, x_2, x_3 \dots x_m\} {x1​,x2​,x3​…xm​},那麼目标就是要去最大化這些樣本産生的log likelihood: L = arg max ⁡ θ l o g ∏ i = 1 m P G ( x ; θ ) \begin{aligned} L = \argmax \limits_{\theta} log \prod\limits_{i=1}^m P_G(x;\theta) \end{aligned} L=θargmax​logi=1∏m​PG​(x;θ)​這個公式的意思就是從 P d a t a P_{data} Pdata​采樣出的樣本用 P G P_G PG​這個分布來産生,log likelihood越大越好。對上述公式進行進一步化簡:

L = arg max ⁡ θ l o g ∏ i = 1 m P G ( x i ; θ ) L = arg max ⁡ θ ∑ i = 1 m l o g P G ( x i ; θ ) L ≈ arg max ⁡ θ E x ∼ P d a t a [ l o g P G ( x ; θ ) ] L = arg max ⁡ θ ∫ x P d a t a ( x ) l o g P G ( z ; θ ) d x \begin{aligned} L &= \argmax \limits_{\theta} log \prod\limits_{i=1}^m P_G(x_i;\theta)\\ L &= \argmax \limits_{\theta} \sum\limits_{i=1}^m log P_G(x_i;\theta)\\ L &\approx \argmax \limits_{\theta} E_{x\thicksim P_{data}}[log P_G(x;\theta)]\\ L &= \argmax \limits_{\theta} \int\limits_x P_{data}(x)log P_G(z;\theta)dx \end{aligned} LLLL​=θargmax​logi=1∏m​PG​(xi​;θ)=θargmax​i=1∑m​logPG​(xi​;θ)≈θargmax​Ex∼Pdata​​[logPG​(x;θ)]=θargmax​x∫​Pdata​(x)logPG​(z;θ)dx​上面公式的最後一個,可知,是關于 θ \theta θ進行求最大,是以,如果在該公式後面添加一項與 θ \theta θ無關的項将不會影響使該函數最大時 θ \theta θ的值,是以,上式改為: L = arg max ⁡ θ ∫ x P d a t a ( x ) l o g P G ( x ; θ ) d x − ∫ x P d a t a ( x ) l o g P d a t a ( x ) d x L = \argmax \limits_{\theta} \int\limits_x P_{data}(x)log P_G(x;\theta)dx - \int\limits_x P_{data}(x)log P_{data}(x)dx L=θargmax​x∫​Pdata​(x)logPG​(x;θ)dx−x∫​Pdata​(x)logPdata​(x)dx上面的公式實際上是 P d a t a P_{data} Pdata​和 P G P_G PG​的KL散度的相反數,實際上上面的式子等同于: L = arg min ⁡ θ ( P d a t a ∣ ∣ P G ) L = \argmin\limits_{\theta}(P_{data}||P_G) L=θargmin​(Pdata​∣∣PG​)從GAN的角度了解上面的式子,generator定義了一個分布: P G P_G PG​而discriminator即作為評估者,計算 L L L。

進行進一步的普适化——衡量兩個分布之間的差異不僅僅KL散度可行,類似JS散度等也可以。是以,重新定義generator的作用,即定義一個分布 P G P_G PG​,generator的作用就是:

G ∗ = arg min ⁡ θ D i v ( P G , P d a t a ) G^*=\argmin\limits_\theta Div(P_G,P_{data}) G∗=θargmin​Div(PG​,Pdata​)其中 D i v Div Div表示某一種散度(divergence)。此時如果是已知 P d a t a P_{data} Pdata​分布,則可以很好的解決上述問題(例如資料服從正态分布,那麼設定正态分布參數未知,使用BP可以進行很好的拟合),但現在分布未知。

前文說過discriminator類似一個評估者,目的是分别真實樣本還是生成樣本——可以視為二分類任務。那麼,可公式化其目标函數:

V ( G , D ) = E x ∼ P d a t a [ l o g D ( x ) ] + E x ∼ P G [ l o g ( 1 − D ( x ) ) ] V(G,D) = E_{x\thicksim P_{data}}[logD(x)] + E_{x\thicksim P_G}[log(1-D(x))] V(G,D)=Ex∼Pdata​​[logD(x)]+Ex∼PG​​[log(1−D(x))]那麼discriminator訓練的目标就是最大化這個式子,即 D ∗ = arg max ⁡ D V ( D , G ) D^* = \argmax\limits_D V(D,G) D∗=Dargmax​V(D,G)假設 D ( x ) D(x) D(x)可能是任何函數——實際上無法做到,隻有當參數無窮多的時候才能做到,将上式進行繼續改寫:

V ( D , G ) = ∫ x [ P d a t a ( x ) l o g ( D ( x ) ) + P G ( x ) l o g ( 1 − D ( x ) ) ] d x V(D,G) = \int\limits_x [P_{data}(x)log(D(x)) + P_G(x)log(1-D(x))]dx V(D,G)=x∫​[Pdata​(x)log(D(x))+PG​(x)log(1−D(x))]dx當訓練discriminator的時候通常會将generator進行固定,此時可将 P G P_G PG​視為常數, P d a t a P_{data} Pdata​顯然是一個常數。那麼上式可以寫做 V = ∫ x a l o g ( D ( x ) ) + b l o g ( 1 − D ( x ) ) d x V = \int\limits_x a log(D(x)) + b log(1-D(x))dx V=x∫​alog(D(x))+blog(1−D(x))dx要使得積分最大,如果每一個位置被積函數最大,那麼積分最大。通過對被積函數 f = a l o g ( D ) + b l o g ( 1 − D ) f = a log(D) + b log(1-D) f=alog(D)+blog(1−D),令導數等于零:

∂ f ∂ x = 0 a D = b 1 − D \begin{aligned} \frac{\partial{f}}{\partial{x}} &= 0\\ \frac{a}{D} &= \frac{b}{1-D} \end{aligned} ∂x∂f​Da​​=0=1−Db​​可以得到 D D D的表示并回帶進 V V V(後面省略分布裡面的 x x x),讓被積函數分子分母同時乘以 1 2 \frac{1}{2} 21​:

D = P d a t a P d a t a + P G V = ∫ x P d a t a l o g 1 2 P d a t a 1 2 ( P d a t a + P G ) d x + ∫ x P G l o g 1 2 P G 1 2 ( P d a t a + P G ) d x V = − 2 l o g 2 + ∫ x P d a t a l o g P d a t a 1 2 ( P d a t a + P G ) d x + ∫ x P G l o g w P G 1 2 ( P d a t a + P G ) d x V = − 2 l o g 2 + K L ( P d a t a ∣ ∣ P d a t a + P G 2 ) + K L ( P G ∣ ∣ P d a t a + P G 2 ) V = − 2 l o g 2 + 2 J S D ( P d a t a ∣ ∣ P G ) \begin{aligned} D &= \frac{P_{data}}{P_{data} + P_G}\\ V &= \int\limits_x P_{data}log\frac{\frac{1}{2}P_{data}}{\frac{1}{2}(P_{data}+P_G)}dx + \int\limits_x P_Glog\frac{\frac{1}{2}P_{G}}{\frac{1}{2}(P_{data}+P_G)}dx \\ V &= -2log2 + \int\limits_x P_{data}log\frac{P_{data}}{\frac{1}{2}(P_{data}+P_G)}dx + \int\limits_x P_Glog\frac{wP_{G}}{\frac{1}{2}(P_{data}+P_G)}dx \\ V &= -2log2 + KL(P_{data}||\frac{P_{data} + P_G}{2}) + KL(P_{G}||\frac{P_{data} + P_G}{2})\\ V &= -2log2 +2JSD(P_{data}||P_G) \end{aligned} DVVVV​=Pdata​+PG​Pdata​​=x∫​Pdata​log21​(Pdata​+PG​)21​Pdata​​dx+x∫​PG​log21​(Pdata​+PG​)21​PG​​dx=−2log2+x∫​Pdata​log21​(Pdata​+PG​)Pdata​​dx+x∫​PG​log21​(Pdata​+PG​)wPG​​dx=−2log2+KL(Pdata​∣∣2Pdata​+PG​​)+KL(PG​∣∣2Pdata​+PG​​)=−2log2+2JSD(Pdata​∣∣PG​)​其中 J S D JSD JSD表示JS散度。是以,discriminator的任務就是辨識生成資料和原始訓練資料分布的不同,具體衡量名額是JS散度。

定義完discriminator借着從數學角度了解generator。generator的目标任務是: G ∗ = arg min ⁡ G ( arg max ⁡ D V ( G , D ) ) G^* = \argmin\limits_G(\argmax\limits_D V(G,D)) G∗=Gargmin​(Dargmax​V(G,D))在訓練G的時候discriminator是固定的,是以訓練目标可以修改為: L o s s = E x ∼ P G [ l o g ( 1 − D ( x ) ) ] Loss = E_{x\thicksim P_G}[log(1-D(x))] Loss=Ex∼PG​​[log(1−D(x))]對于GAN的具體訓練有兩個需要注意的點

  • generator和discriminator參數更新略有不同
  • 剛開始 D ( x ) D(x) D(x)非常小

    針對第一個點,參考下圖:

    GAN(對抗生成網絡)的數學原理及基本算法

    兩次更新之後的generator由于差異較大,将造成同一個discriminator不能同等地衡量這兩個generator的效率——從數學角度來講,訓練generator的時候,依舊假設D的最大值點不變,但是因為實際上G的參數更新過多, D 0 ∗ D_0^* D0∗​已經不再是全局最大值點了,這樣訓練的G将會産生問題。是以實際操作,每次疊代,discriminator訓練到底,generator訓練更新次數較少。

    針對第二點,剛開始的generator生成效果差,是以從函數圖像看:

    GAN(對抗生成網絡)的數學原理及基本算法

    l o g ( 1 − D ( x ) ) log(1-D(x)) log(1−D(x))非常接近0,那麼可知斜率很小,參數更新非常艱難。是以這裡使用 − l o g ( D ( x ) ) -log(D(x)) −log(D(x))替代原來的函數。從圖像上可以看到兩者的大小對應——同時最大、最小。

    參考資料:李宏毅老師GAN講解視訊

繼續閱讀