天天看點

看穿機器學習(W-GAN模型)的黑箱

看穿機器學習(W-GAN模型)的黑箱
看穿機器學習(W-GAN模型)的黑箱

圖a. principle of gan.

前兩天紐約暴雪,天地一片蒼茫。今天元宵節,長島依然清冷寂寥,正月十五鬧花燈的喧嚣熱鬧已成為悠遠的回憶。這學期,老顧在講授一門研究所學生水準的數字幾何課程,目前講到了2016年和丘成桐先生、羅鋒教授共同完成的一個幾何定理【3】,這個工作給出了經典亞曆山大定理(alexandrov theorem)的構造性證明,也給出了最優傳輸理論(optimal mass transportation)的一個幾何解釋。這幾天,機器學習領域的wasserstein gan突然變得火熱,其中關鍵的概念可以完全用我們的理論來給出幾何解釋,這允許我們在一定程度上親眼“看穿”傳統機器學習中的“黑箱”。下面是老顧下周一授課的講稿。

生成對抗網絡 gan

訓練模型 生成對抗網絡gan (generative adversarial networks)是一個“自相沖突”的系統,就是以己之矛克以己之盾,在沖突中發展,使得矛更加鋒利,盾更加強韌。這裡的矛被稱為是判别器(descriminator),這裡的盾被稱為是生成器(generator)。

看穿機器學習(W-GAN模型)的黑箱

圖b. generative model.

生成器g一般是将一個随機變量(例如高斯分布,或者均勻分布),通過參數化的機率生成模型(通常是用一個深度神經網來進行參數化),進行機率分布的逆變換采樣,進而得到一個生成的機率分布。判别器d也通常采用深度卷積神經網。

看穿機器學習(W-GAN模型)的黑箱

圖1. gan的算法流程圖。

沖突的交鋒過程如下:給定真實的資料,其内部的統計規律表示為機率分布

看穿機器學習(W-GAN模型)的黑箱

,我們的目的就是能夠找出

看穿機器學習(W-GAN模型)的黑箱

。為此,我們制作了一個随機變量生成器g,g能夠産生随機變量,其機率分布是

看穿機器學習(W-GAN模型)的黑箱

,我們希望

看穿機器學習(W-GAN模型)的黑箱

盡量接近

看穿機器學習(W-GAN模型)的黑箱

。為了區分真實機率分布

看穿機器學習(W-GAN模型)的黑箱

和生成機率分布

看穿機器學習(W-GAN模型)的黑箱

,我們又制作了一個判别器d,給定一個樣本,d來複制判别這個樣本是來自真實資料還是來自僞造資料。goodfellow給gan中的判别器設計了如下的損失函數(lost function), 盡可能将真實樣本判為正例,生成樣本判為負例:

看穿機器學習(W-GAN模型)的黑箱

第一項不依賴于生成器g, 此式也可以定義gan中的生成器的損失函數。

在訓練中,判别器d和生成器g交替學習,最終達到納什均衡(零和遊戲),判别器無法區分真實樣本和生成樣本。

優點 gan具有非常重要的優越性。當真實資料的機率分布不可計算的時候,傳統依賴于資料内在解釋的生成模型無法直接應用。但是gan依然可以使用,這是因為gan引入了内部對抗的訓練機制,能夠逼近一下難以計算的機率分布。更為重要的,yann lecun一直積極倡導gan,因為gan為無監督學習提供了一個強有力的算法架構,而無監督學習被廣泛認為是通往人工智能重要的一環。

缺點 原始gan形式具有緻命缺陷:判别器越好,生成器的梯度消失越嚴重。我們固定生成器g來優化判别器d。考察任意一個樣本

看穿機器學習(W-GAN模型)的黑箱

,其對判别器損失函數的貢獻是

看穿機器學習(W-GAN模型)的黑箱

兩邊對

看穿機器學習(W-GAN模型)的黑箱

求導,得到最優判别器函數

看穿機器學習(W-GAN模型)的黑箱

代入生成器損失函數,我們得到所謂的jensen-shannon散度(js)

看穿機器學習(W-GAN模型)的黑箱

在這種情況下(判别器最優),如果

看穿機器學習(W-GAN模型)的黑箱

的支撐集合(support)交集為零測度,則生成器的損失函數恒為0,梯度消失。

改進 本質上,js散度給出了機率分布

看穿機器學習(W-GAN模型)的黑箱

之間的差異程度,亦即機率分布間的度量。我們可以用其他的度量來替換js散度。wasserstein距離就是一個好的選擇,因為即便

看穿機器學習(W-GAN模型)的黑箱

的支撐集合(support)交集為零測度,它們之間的wasserstein距離依然非零。這樣,我們就得到了wasserstein gan的模式【1】【2】。wasserstein距離的好處在于即便

看穿機器學習(W-GAN模型)的黑箱

兩個分布之間沒有重疊,wasserstein距離依然能夠度量它們的遠近。

為此,我們引入最優傳輸的幾何理論(optimal mass transportation),這個理論可視化了w-gan的關鍵概念,例如機率分布,機率生成模型(生成器),wasserstein距離。更為重要的,這套理論中,所有的概念,原理都是透明的。例如,對于機率生成模型,理論上我們可以用最優傳輸的架構取代深度神經網絡來構造生成器,進而使得黑箱透明。

最優傳輸理論梗概

給定歐氏空間中的一個區域

看穿機器學習(W-GAN模型)的黑箱

,上面定義有兩個機率測度

看穿機器學習(W-GAN模型)的黑箱

看穿機器學習(W-GAN模型)的黑箱

,滿足

看穿機器學習(W-GAN模型)的黑箱

,

我們尋找一個區域到自身的同胚映射(diffeomorphism),

看穿機器學習(W-GAN模型)的黑箱

, 滿足兩個條件:保持測度和極小化傳輸代價。

保持測度 對于一切波萊爾集

看穿機器學習(W-GAN模型)的黑箱
看穿機器學習(W-GAN模型)的黑箱

換句話說映射t将機率分布

看穿機器學習(W-GAN模型)的黑箱

映射成了機率分布

看穿機器學習(W-GAN模型)的黑箱

,記成 

看穿機器學習(W-GAN模型)的黑箱

。直覺上,自映射

看穿機器學習(W-GAN模型)的黑箱

,帶來體積元的變化,是以改變了機率分布。我們用

看穿機器學習(W-GAN模型)的黑箱
看穿機器學習(W-GAN模型)的黑箱

來表示機率密度函數,用

看穿機器學習(W-GAN模型)的黑箱

來表示映射的雅克比矩陣(jacobian matrix),那麼保持測度的微分方程應該是:

看穿機器學習(W-GAN模型)的黑箱
看穿機器學習(W-GAN模型)的黑箱

這被稱為是雅克比方程(jacobian equation)。

最優傳輸映射 自映射

看穿機器學習(W-GAN模型)的黑箱

的傳輸代價(transportation cost)定義為

看穿機器學習(W-GAN模型)的黑箱

在所有保持測度的自映射中,傳輸代價最小者被稱為是最優傳輸映射(optimal mass transportation map),亦即:

看穿機器學習(W-GAN模型)的黑箱

最優傳輸映射的傳輸代價被稱為是機率測度

看穿機器學習(W-GAN模型)的黑箱

和機率測度

看穿機器學習(W-GAN模型)的黑箱

之間的wasserstein距離,記為

看穿機器學習(W-GAN模型)的黑箱

在這種情形下,brenier證明存在一個凸函數

看穿機器學習(W-GAN模型)的黑箱

,其梯度映射

看穿機器學習(W-GAN模型)的黑箱

就是唯一的最優傳輸映射。這個凸函數被稱為是brenier勢能函數(brenier potential)。

由jacobian方程,我們得到brenier勢滿足蒙日-安培方程,梯度映射的雅克比矩陣是brenier勢能函數的海森矩陣(hessian matrix),

看穿機器學習(W-GAN模型)的黑箱

蒙日-安培方程解的存在性、唯一性等價于經典的凸幾何中的亞曆山大定理(alexandrov theorem)。

看穿機器學習(W-GAN模型)的黑箱

圖2. 亞曆山大定理。

亞曆山大定理  如圖2所示,給定平面凸區域

看穿機器學習(W-GAN模型)的黑箱

,考察一個開放的凸多面體

看穿機器學習(W-GAN模型)的黑箱

,標明一個面

看穿機器學習(W-GAN模型)的黑箱
看穿機器學習(W-GAN模型)的黑箱

的法向量記為

看穿機器學習(W-GAN模型)的黑箱
看穿機器學習(W-GAN模型)的黑箱

的投影和

看穿機器學習(W-GAN模型)的黑箱

相交的面積記為

看穿機器學習(W-GAN模型)的黑箱

,則總投影面積滿足

看穿機器學習(W-GAN模型)的黑箱

凸多面體可以被

看穿機器學習(W-GAN模型)的黑箱

确定。亞曆山大定理對任意維凸多面體都成立。

後面,我們可以看到,這個凸多面體就是brenier勢能函數,其梯度映射将一個機率分布

看穿機器學習(W-GAN模型)的黑箱

映到另外一個機率分布

看穿機器學習(W-GAN模型)的黑箱

,并且這兩個機率分布之間的wasserstein 距離對偶于此凸多面體決定的體積。理論上,這個凸多面體可以作為w-gan模型中的生成器g。

w-gan中關鍵概念可視化

wasserstein-gan模型中,關鍵的概念包括機率分布(機率測度),機率測度間的最優傳輸映射(生成器),機率測度間的wasserstein距離。下面,我們詳細解釋每個概念所對應的構造方法,和相應的幾何意義。

機率分布 gan模型中有兩個至關重要的機率分布(probability measure),一個是真實資料的機率分布

看穿機器學習(W-GAN模型)的黑箱

,一個是生成資料的機率分布

看穿機器學習(W-GAN模型)的黑箱

。另外,生成器的輸入随機變量,滿足标準機率分布(高斯、均勻分布)。

看穿機器學習(W-GAN模型)的黑箱

    圖3. 由保角變換(conformal mapping)誘導的圓盤上機率測度。

機率測度可以看成是一種推廣的面積(或者體積)。我們可以用幾何變換随意構造一個機率測度。如圖3所示,我們用三維掃描器擷取一張人臉曲面,那麼人臉曲面上的面積就是一個機率測度。我們縮放變換人臉曲面,使得總曲面等于

看穿機器學習(W-GAN模型)的黑箱

。然後,我們用保角變換将人臉曲面映射到平面圓盤。如圖3所示,保角變換将人臉曲面上的無窮小圓映到平面上的無窮小圓,但是,小圓的面積發生了變化。每對小圓的面積比率定義了平面圓盤上的機率密度函數。

我們可以将以上的描述嚴格化。人臉曲面記為

看穿機器學習(W-GAN模型)的黑箱

,其上具有黎曼度量

看穿機器學習(W-GAN模型)的黑箱

。平面圓盤記為

看穿機器學習(W-GAN模型)的黑箱

,平面坐标為

看穿機器學習(W-GAN模型)的黑箱

,平面的歐氏度量為

看穿機器學習(W-GAN模型)的黑箱

。保角映射記為

看穿機器學習(W-GAN模型)的黑箱

看穿機器學習(W-GAN模型)的黑箱

,這裡面積變換率函數

看穿機器學習(W-GAN模型)的黑箱

給出了機率密度函數。

看穿機器學習(W-GAN模型)的黑箱

誘導了圓盤

看穿機器學習(W-GAN模型)的黑箱

上的一個機率測度

看穿機器學習(W-GAN模型)的黑箱
看穿機器學習(W-GAN模型)的黑箱

圖4. 兩個機率測度之間的最優傳輸映射。

最優傳輸映射 圓盤上本來有均勻分布

看穿機器學習(W-GAN模型)的黑箱

,又有保角變換誘導的機率分布

看穿機器學習(W-GAN模型)的黑箱

,則存在唯一的最優傳輸映射

看穿機器學習(W-GAN模型)的黑箱

。圖4顯示了這個映射

看穿機器學習(W-GAN模型)的黑箱

,中間幀到右幀的映射就是最優傳輸映射。我們看到,鼻尖周圍的區域被壓縮,機率密度提高。

看穿機器學習(W-GAN模型)的黑箱

圖5. 離散最優傳輸。

離散最優傳輸映射 最優傳輸映射的數值計算非常幾何化,是以可以直接被可視化。我們将目标機率測度離散化,表示成一族離散點,

看穿機器學習(W-GAN模型)的黑箱

;每點被賦予一個狄拉克測度,

看穿機器學習(W-GAN模型)的黑箱
看穿機器學習(W-GAN模型)的黑箱

。然後,我們求得機關圓盤的一個胞腔分解,

看穿機器學習(W-GAN模型)的黑箱

,每個胞腔

看穿機器學習(W-GAN模型)的黑箱

映到相應的目标點

看穿機器學習(W-GAN模型)的黑箱
看穿機器學習(W-GAN模型)的黑箱

。映射保持機率測度,胞腔的面積等于目标測度,

看穿機器學習(W-GAN模型)的黑箱

同時極小化傳輸代價,

看穿機器學習(W-GAN模型)的黑箱
看穿機器學習(W-GAN模型)的黑箱

圖6. 離散brenier勢能函數,離散最優傳輸映射。

離散brenier勢能 離散最優傳輸映射是離散brenier勢能函數的梯度映射。對于每一個目标離散點

看穿機器學習(W-GAN模型)的黑箱

,我們構造一個平面 

看穿機器學習(W-GAN模型)的黑箱

,這裡平面的截距

看穿機器學習(W-GAN模型)的黑箱

是未知變量。這些平面的上包絡(upper envelope)構成一個開放的凸多面體,恰為離散brenier勢能函數

看穿機器學習(W-GAN模型)的黑箱

的圖(graph),

看穿機器學習(W-GAN模型)的黑箱

圖6左側顯示了離散briener勢能函數。凸多面體在平面上的投影構成了平面的胞腔分解,凸多面體的每個面

看穿機器學習(W-GAN模型)的黑箱

被映成了一個胞腔

看穿機器學習(W-GAN模型)的黑箱

;每個面

看穿機器學習(W-GAN模型)的黑箱

的梯度都是

看穿機器學習(W-GAN模型)的黑箱

,是以brenier勢能函數的梯度映射就是

看穿機器學習(W-GAN模型)的黑箱

根據保測度性質,每個胞腔

看穿機器學習(W-GAN模型)的黑箱

的面積應該等于指定面積

看穿機器學習(W-GAN模型)的黑箱

。由此,我們調節平面的截距

看穿機器學習(W-GAN模型)的黑箱

以滿足這個限制。根據亞曆山大定理,這種截距存在,并且本質上唯一。

離散wasserstein距離 我們和丘成桐先生建立了變分法來求取平面的截距

看穿機器學習(W-GAN模型)的黑箱

。給定截距向量

看穿機器學習(W-GAN模型)的黑箱

,平面族為

看穿機器學習(W-GAN模型)的黑箱

,其上包絡構成的briener勢能函數為 

看穿機器學習(W-GAN模型)的黑箱

, 上包絡的投影生成了平面的胞腔分解

看穿機器學習(W-GAN模型)的黑箱

, 胞腔的面積記為

看穿機器學習(W-GAN模型)的黑箱

。我們定義的能量為,

看穿機器學習(W-GAN模型)的黑箱

這個能量在子空間

看穿機器學習(W-GAN模型)的黑箱

 上是嚴格凹的,其唯一的全局最大點就給出了滿足保測度條件的截距。這個能量的非線性項,實際上是上包絡截出的柱體體積,

看穿機器學習(W-GAN模型)的黑箱

圖7給出了柱體體積的可視化,柱體體積

看穿機器學習(W-GAN模型)的黑箱

是凸函數。

看穿機器學習(W-GAN模型)的黑箱

圖7. 離散brenier勢能函數的圖截出的柱體體積

看穿機器學習(W-GAN模型)的黑箱

體積函數

看穿機器學習(W-GAN模型)的黑箱

和wasserstein距離之間相差一個勒讓德變換(legendre transformation)。勒讓德變換非常幾何化,我們可以将其可視化。給定一個定義在實數軸上的二階光滑凸函數

看穿機器學習(W-GAN模型)的黑箱

,其圖

看穿機器學習(W-GAN模型)的黑箱

是一條凸曲線,這條凸曲線由其所有的切線包絡而成。如果,在任意一點

看穿機器學習(W-GAN模型)的黑箱

,函數的切線的斜率為y,則此切線的截距滿足

看穿機器學習(W-GAN模型)的黑箱

這被稱為是函數

看穿機器學習(W-GAN模型)的黑箱

的勒讓德變換。

看穿機器學習(W-GAN模型)的黑箱

以切線的斜率為參數,以切線的截距為函數值。

看穿機器學習(W-GAN模型)的黑箱

圖8.凸函數的圖像由其切線包絡而成,切線集合被表示成原函數的勒讓德對偶。

因為

看穿機器學習(W-GAN模型)的黑箱

的凸性,映射

看穿機器學習(W-GAN模型)的黑箱

是微分同胚,記為

看穿機器學習(W-GAN模型)的黑箱

。那麼,原函數和勒讓德變換後的函數滿足關系:

看穿機器學習(W-GAN模型)的黑箱

這裡c,d是常數。原函數和其勒讓德變換的直覺圖解由圖9給出。我們在xy-平面上畫出曲線

看穿機器學習(W-GAN模型)的黑箱

,曲線下面的面積是

看穿機器學習(W-GAN模型)的黑箱

,曲線上面的面積是勒讓德變換

看穿機器學習(W-GAN模型)的黑箱
看穿機器學習(W-GAN模型)的黑箱

圖9. 圖解勒讓德變換。

勒讓德變換的幾何圖景對任意維都對。我們下面來考察體積函數

看穿機器學習(W-GAN模型)的黑箱

的勒讓德變換

看穿機器學習(W-GAN模型)的黑箱

。根據定義,

看穿機器學習(W-GAN模型)的黑箱

假如我們變動截距

看穿機器學習(W-GAN模型)的黑箱

,或者等價地變動胞腔面積

看穿機器學習(W-GAN模型)的黑箱

,考察兩個胞腔交界處

看穿機器學習(W-GAN模型)的黑箱
看穿機器學習(W-GAN模型)的黑箱

p本來屬于

看穿機器學習(W-GAN模型)的黑箱

,變化後屬于

看穿機器學習(W-GAN模型)的黑箱

,所有這種點的總面積為

看穿機器學習(W-GAN模型)的黑箱

。則為wasserstein距離帶來的變化是:

看穿機器學習(W-GAN模型)的黑箱

是以,總的wasserstein距離的變化是

看穿機器學習(W-GAN模型)的黑箱

由此我們看到wasserstein距離等于

看穿機器學習(W-GAN模型)的黑箱

其非線性部分是柱體積的勒讓德變換。

總結

通過以上讨論,我們看到給定兩個機率分布

看穿機器學習(W-GAN模型)的黑箱

,則存在唯一的一個凸函數(brenier 勢函數)

看穿機器學習(W-GAN模型)的黑箱
看穿機器學習(W-GAN模型)的黑箱

把一個機率分布

看穿機器學習(W-GAN模型)的黑箱

映成了另外一個機率分布。這個最優傳輸映射的傳輸代價就給出了兩個機率分布之間的wasserstein距離。brenier勢能函數,wasserstein距離都有明晰的幾何解釋。

在wasserstein-gan模型中,通常生成器和判别器是用深度神經網絡來實作的。根據最優傳輸理論,我們可以用briener勢函數來代替深度神經網絡這個黑箱,進而使得整個系統變得透明。在另一層面上,深度神經網絡本質上是在訓練機率分布間的傳輸映射,是以有可能隐含地在學習最優傳輸映射,或者等價地brenier勢能函數。對這些問題的深入了解,将有助于我們看穿黑箱。

看穿機器學習(W-GAN模型)的黑箱

圖10. 基于二維最優傳輸映射計算的曲面保面積參數化(area preserving parameterization),蘇政宇作。

看穿機器學習(W-GAN模型)的黑箱
看穿機器學習(W-GAN模型)的黑箱

圖11. 基于三維最優傳輸映射計算的保體積參數化 (volume preserving parameterization),蘇科華作。

原文釋出時間為:2017-02-19

本文來自雲栖社群合作夥伴“大資料文摘”,了解相關資訊可以關注“bigdatadigest”微信公衆号

繼續閱讀