天天看點

原始GAN-pytorch-生成MNIST資料集(原理)

文章目錄

    • 1. GAN 《Generative Adversarial Nets》
      • 1.1 相關概念
      • 1.2 公式了解
      • 1.3 圖檔了解
      • 1.4 熵、交叉熵、KL散度、JS散度
      • 1.5 其他相關(正在補充!)

1. GAN 《Generative Adversarial Nets》

Ian J. Goodfellow, Jean Pouget-Abadie, Yoshua Benjio etc.

https://dl.acm.org/doi/10.5555/2969033.2969125

1.1 相關概念

生成模型:學習得到聯合機率分布 P ( x , y ) P(x,y) P(x,y),即特征x和标簽y同時出現的機率,然後可以求條件機率分布和其他機率分布。學習到的是資料生成的機制。
判别模型: 學習得到條件機率分布 P ( y ∣ x ) P(y|x) P(y∣x),即在特征x出現的情況下标記y出現的機率

學習一個分布和近似一個分布?

1.2 公式了解

GAN的似然函數(損失函數還要加上一個負号哦):

m i n G m a x D V ( D , G ) = E x ∼ P d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] (1.1) \underset{G}{min}\underset{D}{max}V(D,G) = E_{x \sim P_{data}(x)}[log D(x)]+E_{z\sim p_{z}(z)}[log(1-D(G(z)))] \tag{1.1} Gmin​Dmax​V(D,G)=Ex∼Pdata​(x)​[logD(x)]+Ez∼pz​(z)​[log(1−D(G(z)))](1.1)

為了學習資料x的分布 p g p_g pg​,定義了一個含有噪聲的變量分布 p z ( z ) p_z(z) pz​(z);V是評分方程(這個值是越大越好的),G是一個生成器,D是一個判别器;訓練D最大化真實資料和生成資料的差別,訓練G最小化真實資料和生成資料的差別;

注意這個公式有兩項,第一項是指是否能正确識别真實的資料;第二項是指是否能夠識别生成的資料;

(1) 完美D

  1. 當 D ( x ) D(x) D(x)完美識别真實資料和生成資料, E x ∼ P d a t a ( x ) [ l o g D ( x ) ] E_{x\sim P_{data}(x)}[log D(x)] Ex∼Pdata​(x)​[logD(x)]趨近于1,而 E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] E_{z\sim p_{z}(z)}[log(1-D(G(z)))] Ez∼pz​(z)​[log(1−D(G(z)))]趨近于0,整體趨近于1.
  2. 當 D D D不完美的時候,由于存在 l o g log log會使得兩項都是一個負數;那訓練的目的就是使得這個負數盡量小
  3. 是以需要最大化判别器帶來的值,來使得判别器D最佳。

(2) 完美G

  1. G隻和 E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] E_{z\sim p_{z}(z)}[log(1-D(G(z)))] Ez∼pz​(z)​[log(1−D(G(z)))]相關,如果G完美忽悠D的時候, E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] E_{z\sim p_{z}(z)}[log(1-D(G(z)))] Ez∼pz​(z)​[log(1−D(G(z)))]輸出的結果就是負無窮;
  2. 當不是那麼完美的時候,輸出的值就是一個負數;我們目的是使得這個輸出盡量小,以使得生成器最佳。
  3. 是以需要最小化生成器帶來值 E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] E_{z\sim p_{z}(z)}[log(1-D(G(z)))] Ez∼pz​(z)​[log(1−D(G(z)))]

訓練過程

訓練D說明

生成器生成的資料就是V(G,D)的第二項的輸入: g ( z ) = x g(z) = x g(z)=x,那麼對z的求和就可以變為對x的求和。

将 V ( G , D ) V(G,D) V(G,D)展開成積分/求和的形式

V ( G , D ) = ∫ x p d a t a ⋅ l o g ( D ( x ) ) d x + ∫ z p z ( z ) ⋅ l o g ( 1 − D ( g ( z ) ) ) = ∫ x p d a t a ⋅ l o g ( D ( x ) ) + p g ( x ) ⋅ l o g ( 1 − D ( x ) ) d x (1.2) \begin{aligned} V(G,D) &= \int_x p_{data} \cdot log(D(x))dx + \int_z p_z(z) \cdot log(1-D(g(z))) \\ &=\int_x p_{data} \cdot log(D(x)) + p_g(x) \cdot log(1-D(x))dx \end{aligned} \tag{1.2} V(G,D)​=∫x​pdata​⋅log(D(x))dx+∫z​pz​(z)⋅log(1−D(g(z)))=∫x​pdata​⋅log(D(x))+pg​(x)⋅log(1−D(x))dx​(1.2)

對于 任意的 ( a , b ) ∈ R 2 \ { 0 , 0 } (a,b) \in R^2 \backslash \{0,0\} (a,b)∈R2\{0,0},函數 y → a l o g ( y ) + b l o g ( 1 − y ) y \rightarrow a log(y) + blog(1-y) y→alog(y)+blog(1−y)是一個凸函數,我們需要求這個函數的最大值,就求導數

a y + b 1 − y = 0 y = a a + b \begin{aligned} \frac{a}{y}+\frac{b}{1-y} = 0 \\ y = \frac{a}{a+b} \end{aligned} ya​+1−yb​=0y=a+ba​​

則在 y = a a + b y = \frac{a}{a+b} y=a+ba​的時候有最大值,對應于判别器的機率即為:

D G ∗ ( x ) = p d a t a ( x ) p d a t a ( x ) + p g ( x ) D_G^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)} DG∗​(x)=pdata​(x)+pg​(x)pdata​(x)​

将最優解帶入到價值函數之中

C ( G ) = m a x D V ( G , D ) = E x ∼ p d a t a [ l o g D G ∗ ( x ) ] + E z ∼ p z [ l o g ( 1 − D G ∗ ( G ( z ) ) ) ] = E x ∼ p d a t a [ l o g p d a t a ( x ) p d a t a ( x ) + p g ( x ) ] + E x ∼ p g [ l o g p g ( x ) p d a t a ( x ) + p g ( x ) ] (1.3) \begin{aligned} C(G) &= \underset{D}{max}V(G,D) \\ &= E_{x \sim p_{data}}[log D_G^*(x)] + E_{z \sim p_z}[log(1-D_G^*(G(z)))] \\ &= E_{x \sim p_{data}}[log \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}] + E_{x \sim p_g}[log \frac{p_g(x)}{p_{data}(x) + p_g(x)}] \end{aligned} \tag{1.3} C(G)​=Dmax​V(G,D)=Ex∼pdata​​[logDG∗​(x)]+Ez∼pz​​[log(1−DG∗​(G(z)))]=Ex∼pdata​​[logpdata​(x)+pg​(x)pdata​(x)​]+Ex∼pg​​[logpdata​(x)+pg​(x)pg​(x)​]​(1.3)

根據KL散度和JS散度的定義,可以将上面的公式改寫為

C ( G ) = K L ( P d a t a ∣ ∣ p d a t a + p g 2 ) + K L ( p g ∣ ∣ p d a t a + p g 2 ) − l o g ( 4 ) = 2 ⋅ J S D ( p d a t a ∣ ∣ p g ) − l o g ( 4 ) (1.4) \begin{aligned} C(G) &= KL(P_{data} || \frac{p_{data}+p_g}{2}) + KL(p_g || \frac{p_{data}+p_g}{2}) -log(4) \\ &= 2 \cdot JSD(p_{data}||p_g) - log(4) \end{aligned} \tag{1.4} C(G)​=KL(Pdata​∣∣2pdata​+pg​​)+KL(pg​∣∣2pdata​+pg​​)−log(4)=2⋅JSD(pdata​∣∣pg​)−log(4)​(1.4)

注意 p d a t a + p g 2 \frac{p_{data}+p_g}{2} 2pdata​+pg​​這裡除以2是為了保證是一個分布(即機率的積分是等于1的)

在固定D訓練G的時候,我們就是為了最小化這個 C ( G ) C(G) C(G),根據上面推導:

是以給出結論:當 p g = p d p_g = p_d pg​=pd​時, D G ∗ ( x ) = 1 2 D_G^*(x) = \frac{1}{2} DG∗​(x)=21​,是以 C ( G ) = l o g 1 2 + 1 2 = − l o g 4 C(G) = log\frac{1}{2} + \frac{1}{2} = -log4 C(G)=log21​+21​=−log4,可以得到最小的 C ( G ) C(G) C(G)

1.3 圖檔了解

原始GAN-pytorch-生成MNIST資料集(原理)

綠色是生成的分布;黑色是真實分布;藍色是判别器的分布

(b)表示訓練辨識器,使得辨識器可以非常好地區分二者

©表示訓練生成器,繼續欺騙判别器

1.4 熵、交叉熵、KL散度、JS散度

  1. (

    Entropy

    )

    K-L散度源于資訊論,常用的資訊度量機關為

    (Entropy)

    H = − ∑ i = 1 N p ( x i ) ⋅ l o g p ( x i ) H = -\sum_{i=1}^{N}p(x_i) \cdot logp(x_i) H=−i=1∑N​p(xi​)⋅logp(xi​)

    注意這個對數沒有确定的底數(可以使2、e或者10)。

熵度量了資料的資訊量,可以幫助我們了解用機率分布近似代替原始分布的時候我們到底損失了多少資訊;但問題是如何将熵值壓縮到最小值,即如何編碼可以達到最小的熵(存儲空間最優化)。

  1. 交叉熵

    : 量化兩個機率分布之間的差異

    H ( p , q ) = − ∑ x p ( x )    l o g    q ( x ) H(p,q) = -\sum_{x}p(x) \; log \; q(x) H(p,q)=−x∑​p(x)logq(x)

  2. KL散度

    kullback-Leibler divergence

    ):量化兩種機率分布 P和Q之間差異的方式,又成為

    相對熵

    将熵的定義公式稍加修改就可以得到

    K-L散度

    的定義公式:

    D K L ( P ∣ ∣ Q ) = ∑ i = 1 N p ( x i ) ⋅ ( l o g p ( x i ) − l o g q ( x i ) ) = ∑ i = 1 N p ( x i ) ⋅ l o g p ( x i ) q ( x i ) D_{KL}(P||Q) = \sum_{i=1}^{N} p(x_i) \cdot (log p(x_i) - log q(x_i)) = \sum_{i=1}^{N}p(x_i) \cdot log \frac{p(x_i)}{q(x_i)} DKL​(P∣∣Q)=i=1∑N​p(xi​)⋅(logp(xi​)−logq(xi​))=i=1∑N​p(xi​)⋅logq(xi​)p(xi​)​

    其中 p p p和 q q q分别表示資料的原始分布和近似的機率分布。

根據公式所示,K-L散度其實是資料的原始分布p和近似分布之間的對數差的期望。如果用2位底數計算,K-L散度表示資訊損失的二進制位數,下面用期望表示式展示:

D K L ( P ∣ ∣ Q ) = E [ l o g p ( x ) − q ( x ) ] D_{KL}(P||Q) = E[log p(x) - q(x)] DKL​(P∣∣Q)=E[logp(x)−q(x)]

注意:

  • 從散度的定義公式中可以看出其不符合對稱性(距離度量應該滿足對稱性)
  • KL散度非負性
  1. JS散度

    (

    Jensen-shannon divergence

    )

    由于K-L散度是非對稱的,是以對其進行修改,使得其能夠對稱,稱之為 JS散度

    (1) 設 M = 1 2 ( P + Q ) M = \frac{1}{2}(P+Q) M=21​(P+Q),則:

    D J S ( P ∣ ∣ Q ) = 1 2 D K L ( P ∣ ∣ M ) + 1 2 D K L ( Q ∣ ∣ M ) D_{JS}(P||Q) = \frac{1}{2}D_{KL}(P||M) + \frac{1}{2}D_{KL}(Q||M) DJS​(P∣∣Q)=21​DKL​(P∣∣M)+21​DKL​(Q∣∣M)

    (2) 将KL散度公式帶入上面

    D J S = 1 2 ∑ i = 1 N p ( x i ) l o g ( p ( x i ) p ( x i ) + q ( x i ) 2 ) + 1 2 ∑ i = 1 N q ( x i ) ⋅ l o g ( q ( x i ) p ( x i ) + q ( x i ) 2 ) D_{JS} = \frac{1}{2}\sum_{i=1}^{N}p(x_i)log(\frac{p(x_i)}{\frac{p(x_i) + q(x_i)}{2}}) + \frac{1}{2}\sum_{i=1}^{N}q(x_i) \cdot log(\frac{q(x_i)}{\frac{p(x_i)+q(x_i)}{2}}) DJS​=21​i=1∑N​p(xi​)log(2p(xi​)+q(xi​)​p(xi​)​)+21​i=1∑N​q(xi​)⋅log(2p(xi​)+q(xi​)​q(xi​)​)

    (3) 将 l o g log log中的 1 2 \frac{1}{2} 21​放到分子上

    D J S = 1 2 ∑ i = 1 N p ( x i ) l o g ( 2 p ( x i ) p ( x i ) + q ( x i ) ) + 1 2 ∑ i = 1 N q ( x i ) ⋅ l o g ( 2 q ( x i ) p ( x i ) + q ( x i ) ) D_{JS} = \frac{1}{2}\sum_{i=1}^{N}p(x_i)log(\frac{2p(x_i)}{p(x_i) + q(x_i)}) + \frac{1}{2}\sum_{i=1}^{N}q(x_i) \cdot log(\frac{2q(x_i)}{p(x_i)+q(x_i)}) DJS​=21​i=1∑N​p(xi​)log(p(xi​)+q(xi​)2p(xi​)​)+21​i=1∑N​q(xi​)⋅log(p(xi​)+q(xi​)2q(xi​)​)

    (4) 提出2

    D J S = 1 2 ∑ i = 1 N p ( x i ) l o g ( p ( x i ) p ( x i ) + q ( x i ) ) + 1 2 ∑ i = 1 N q ( x i ) ⋅ l o g ( q ( x i ) p ( x i ) + q ( x i ) ) + l o g ( 2 ) D_{JS} = \frac{1}{2}\sum_{i=1}^{N}p(x_i)log(\frac{p(x_i)}{p(x_i) + q(x_i)}) + \frac{1}{2}\sum_{i=1}^{N}q(x_i) \cdot log(\frac{q(x_i)}{p(x_i)+q(x_i)}) + log(2) DJS​=21​i=1∑N​p(xi​)log(p(xi​)+q(xi​)p(xi​)​)+21​i=1∑N​q(xi​)⋅log(p(xi​)+q(xi​)q(xi​)​)+log(2)

    注意這裡是因為 ∑ p ( x ) = ∑ q ( x ) = 1 \sum p(x) = \sum q(x) = 1 ∑p(x)=∑q(x)=1

JS散度的缺陷:當兩個分布完全不重疊的時候,幾遍兩個分布的中心離得很近,其JS散度都是一個常數,是以其擷取的梯度是0,是沒有辦法進行更新的。而兩個分布沒有重疊的原因:從理論和經驗而言,真實的資料分布其實是一個低維流形(不具備高維特征),而是存在一個嵌入在高次元的低維空間内。由于次元存在差異,資料很可能不存在分布的重合。

1.5 其他相關(正在補充!)

繼續閱讀