天天看點

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

文章目錄

  • 前言
  • 1 Our Objective
  • 2 Train
    • JS divergence is not suitable
    • WGAN
      • Wasserstein distance
  • 總結

前言

之前老早就聽說了GAN,然後對這個方法還不是很了解,想在今後的論文中應用它。是以來學習下李宏毅講的GAN,記個筆記。視訊位址

1 Our Objective

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

在Generator裡面,我們的目标是由Generator産生的分布(叫做 P G P_G PG​)和真正的data資料的分布(叫做 P d a t a P_{data} Pdata​)越接近越好。

如中間的一維向量而言,輸入的是Normal Distribution,Generator産生的分布和真實的data的分布越接近越好。

然後就可以定義相應的最優化數學公式:

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

其中的Div指的是 P G P_G PG​和 P d a t a P_{data} Pdata​之間的差異度。

也即為:

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

其中的w和b是Generator Network中的weight和bias,L指的是Loss function ,也就是前面提到的 P G P_G PG​和 P d a t a P_{data} Pdata​之間的差異度。

但是計算這種連續分布之間的Divergence是非常困難的。然後GAN給出了它相應的解法:采樣。

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

也就是我們甚至不需要知道 P G P_G PG​和 P d a t a P_{data} Pdata​分布的formulation,也可以通過采樣的方式來計算這兩者之間的差異度。

需要依靠Discriminator的力量了:

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

Discriminator訓練的目标是:給真正的dara打高分,給生成的data打低分。

等價于:

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

具體指的:

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

從這個公式中看出:最大化

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

等價于使得從真實datra采樣出來的D(y)越大,從Generator生成的y中采樣得到的D(y)越小。

上面具體的那個公式,其實是因為最早的GAN的論文作者為了将Discriminator等同于一個binary classifier。因為最大化V(D,G)等價于最小化分類中的交叉熵。

然後這裡有一個神奇的地方:

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

跟JS divergence有關。

前面說到的計算 P G P_G PG​和 P d a t a P_{data} Pdata​之間的差異度非常困難,現在就可以通過訓練Discriminator,

然後利用objective function得到的最大值,這個最大值跟JS divergence有關。

這裡不給出具體的證明,然後給了個形象的表達:

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

簡單說,就是當 P G P_G PG​和 P d a t a P_{data} Pdata​之間的差異度非常小的時候,兩者混在一起,很難分辨,這個優化問題就很難,那麼解得到的

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

的值就不會很大,比較小。是以小的Divergence對應小的

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

如果是差異度非常大的時候,原理是類似的。是以就可以說明

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

跟JS divergence有關。

是以就可以得到關于Generator和Discriminator求解最優化問題的公式:

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

簡單說,就是我們訓練得到的Discriminator的

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

拿去作為之前提到的差異度。

在最開始介紹GAN的算法的時候說的步驟一和步驟二就是為了解決這個Min Max問題。

也就是先固定G,訓練D得到相應公式的最大值。然後固定D,訓練G得到使得G産生最小的相應表達式的值。

至于怎樣設計不同的objective function(指的是前面說的設計Discriminator的目标函數)得到不同的Divergence,有一篇F GAN的文章列出了相應的表。

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

2 Train

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

今天學到了新的:No PAIN, No GAN。牛的。

GAN的訓練有很多小技巧(雖然我也不知道),然後李老師這次主要是想講WGAN。

JS divergence is not suitable

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

首先是說在大多數的例子當中, P G P_G PG​和 P d a t a P_{data} Pdata​是不重疊的。原因有二:

1 資料本身的内在性質: 圖檔其實是高緯空間的低緯manifold(我目前的了解是 一張比如二次元人物的頭像是高緯空間中的一個非常特殊的分布,也就是所謂的manifold。)類比于二維空間,圖像就是二維空間裡面的一條線,而兩條線除非完全重合,否則重疊的部分可以忽略不計。

2 采樣: 就算 P G P_G PG​和 P d a t a P_{data} Pdata​有重合的部分,但是我們并不知道兩者真實的分布,是以在采樣的時候很可能畫出一條泾渭分明的線來将兩者給完全區分開來。

P G P_G PG​和 P d a t a P_{data} Pdata​幾乎不重疊會給JS divergence帶來問題:

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

當P_G 和 和 和P_{data}$不重疊,JS divergence就會一直為log2。但是如上圖所示,就算兩者距離變近了,我們也看不出來差別,除非是兩者達到了重合。

直覺上就是:訓練得到的loss是沒用的,因為一直沒變嘛,然後隻有通過檢視每次訓練得到的圖檔來進行。

WGAN

Wasserstein distance

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

也叫做Earth Mover distance(推土機距離),但是對于更加複雜的distribution,計算Wasserstein distance變得困難了。

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

對于複雜的distribution,計算Wasserstein distance方法是:找到所有的moving plans,然後其中最短的平均距離定義為Wasserstein distance。

Wasserstein distance 和 JS divergence的對比:

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

前者的d會變小,使得在訓練的時候能夠看出不同的差別。

Wasserstein distance 和進化過程的相似之處:

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

計算Wasserstein distance:

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

注意D必須是一個足夠平滑的函數。

如何讓D屬于Lipschitz.有一些方法:

李宏毅機器學習|生成對抗網絡Generative Adversarial Network (GAN)|學習筆記(2)|GAN理論介紹與WGAN前言1 Our Objective2 Train總結

總結

本文介紹了GAN的理論介紹和WGAN,下一篇将介紹Generator相關内容。

繼續閱讀