天天看點

DIM(Learning deep representations by mutual information estimation and maximization)

論文:https://arxiv.org/pdf/1808.06670.pdf

摘要:

許多表示學習隻使用已探索過的資料空間(稱為像素級别),當一小部分資料十分關心語義級别時,該表示學習将不利于訓練。論文提出了無監督表示學習,直接學習和估計資訊内容,統計或結構限制。論文最大化輸入資訊和進階特征向量之間的互資訊與通過對抗比對先驗分布來控制表示學習的特征。

介紹:

人工智能單元在預測和規劃時不應該停留在像素級别或者傳感器級别,而是應該在抽象表示級别。像素級别的非監督機器學習可以在不捕捉語義資訊時表現的非常出色,但是它們并不是好的表示。解決學習一個訓練目标的表示而不适用提前定義好的輸入,一個簡單解決辦法是直接訓練表示學習的函數,最大化輸入和輸出之間的互資訊。在論文 MINE https://arxiv.org/pdf/1801.04062.pdf 中,提供了高效計算高次元神經網絡輸入輸出之間互資訊的計算解決方案。直接最大化輸入和表示之間的互資訊并不能有效的學習有用的表示資訊,但是最大化表示和當地輸入之間的平均互資訊可以極大的提升表示品質。除了互資訊外,表示學習的特征,比如結構也非常重要,論文結合最大化互資訊和比對先驗(類似與AAE算法)來達到好的表現。論文主要貢獻有:① 提出 DIM 算法,使用互資訊神經網絡估計明确的最大化輸入和已經學習的高次元表示之間的互資訊。② 最大化互資訊可以全局或者當地資訊優先,協調使得在表示和分類或者重建任務中更高效。③ 使用對抗學習限制表示資訊的先驗分布的統計特性 ④ 引入了兩個提升表示學習品質的方法,一個是MINE 另一個是 https://arxiv.org/pdf/1710.05050.pdf 。

生成模型:

生成模型依賴重建和對抗,重建誤差與互資訊的聯系可以如下表示

DIM(Learning deep representations by mutual information estimation and maximization)

其中,X 和 Y 分别代表随機變量的輸入和表示,而

DIM(Learning deep representations by mutual information estimation and maximization)

表示重建誤差,

DIM(Learning deep representations by mutual information estimation and maximization)

表示編碼器的邊緣分布的熵。在雙向對抗模型(bi-directional adversarial models)訓練編碼器和解碼器來比對表示的聯合分布,這樣操作會增加邊緣分布的熵或者減小重建誤差。在 GAN 裡面采用生成和對抗模型,辨識器來辨識真假圖檔時需要很高的互資訊值,但是在高次元情境下,學習生成模型非常困難。同時,圖檔中不是所有資訊都很重要,有時候一張圖檔隻有一小部分的特征就可以表示整個圖檔的重要資訊。

免解碼器模型:

依賴最大化似然函數的算法(arXiv:1410.8516,2014),但該算法為了成立一個似然目标函數嚴格限制了編碼器和輸出空間。深度聚類算法(Unsupervised deep embedding for clustering analysis)在非監督聚類中表現優異,但是用途不廣闊。NAT算法将表示作為一個監督學習中的噪聲目标來進行非監督學習,不需要生成模型,但是需要一個推斷機制将輸入和噪聲排列起來。NAT算法需要大量的采樣,并需要訓練先驗分布,同時NAT算法如何影響輸入資料的大小和表示的次元并不清楚。

互資訊估計:

INFOMAX 主張最大化輸入和輸出之間的互資訊。MINE 算法學習連續變量的神經網絡估計的互資訊,通過最大化編碼器的輸入和輸出來限制變量和用于學習更好的生成模型。論文使用 KL 散度,使用層級化輸入的結構來提升表示分類的能力。DIM 使用特征映射對應區域的層級化的采樣,使用 1x1 的卷積來表示當地的小塊區域和全局變量之間的互資訊估計。

DIM:

定義式如下:

DIM(Learning deep representations by mutual information estimation and maximization)
DIM(Learning deep representations by mutual information estimation and maximization)

是一個關于 y 的 Dirac 函數。

如左圖所示為編碼器的示意圖:圖像資訊被編碼為一個卷積神經網絡,卷積過程直到映射 MXM 對應了輸入的 MXM,使用全連接配接整合成一個特征向量,目标是訓練這個神經網絡,這個神經網絡的輸入的相關資訊可在高層特征中抽離出來。

DIM(Learning deep representations by mutual information estimation and maximization)
DIM(Learning deep representations by mutual information estimation and maximization)

如上右圖所示,我們提出一個高維向量 Y 和一個低級别 MXM 的映射通過一個鑒别器來打分,鑒别器由神經網絡,全連接配接網絡組成,假的采樣通過與另一個圖像的相同特征向量結合而描繪出來。

互資訊的估計和最大化:

DIM(Learning deep representations by mutual information estimation and maximization)

其中

DIM(Learning deep representations by mutual information estimation and maximization)

是一個基于參數 w 的鑒别函數,論文同時最大化和估計互資訊

DIM(Learning deep representations by mutual information estimation and maximization)

,如下公式:

DIM(Learning deep representations by mutual information estimation and maximization)

因為編碼器和MINE算法在優化目标函數的時候使用類似的計算方法,是以論文結合了最初的兩種網絡結構:

DIM(Learning deep representations by mutual information estimation and maximization)

論文使用 JSD 散度公式,結合解碼器和MINE的目标函數,得到如下公式:

DIM(Learning deep representations by mutual information estimation and maximization)

其中,y(x)是一個更級别的表示,x' 是與 y 不相關的另外一個輸入,

DIM(Learning deep representations by mutual information estimation and maximization)

,JSD散度公式更适合本論文最大化互資訊,① JSD 的上屆 log2,在計算時不會産生特别大的數 ② JSD 的梯度是無偏的。

最大化當地互資訊:

上述公式是最大化輸入和輸出的互資訊的,但是根本上我們的任務并不需要那麼做,比如當地像素的噪聲,如果最終的目标是分類,那麼這個表示就不太優異。為了保證表示模型能夠适應分類任務,我們最大化進階表示和當地小範圍圖像的平均互資訊。因為相同的表示鼓勵更高的互資訊,某些區域的資料會共用了一部分資料,解碼器可以選擇輸入資訊的類型,但是當解碼器通過某些特定輸入資訊時,不會因為其他的區域不包含上述噪聲而增大互資訊,這将使得解碼器更傾向于輸入中共享的資訊。

如下圖所示:最大化當地特征和進階特征向量之間的互資訊,論文将圖像編碼成一個映射,該映射包含資料的一些結構特征,并且将該映射整合成一個全局特征向量(在上圖可以看到)這個特征向量在每一個區域都連結低級特征映射,一個1x1的卷積鑒别器用來給真實圖檔和假圖檔打分,假圖檔是通過另外一張圖像生成的映射而生成的。

DIM(Learning deep representations by mutual information estimation and maximization)

公式轉化如下:

DIM(Learning deep representations by mutual information estimation and maximization)
DIM(Learning deep representations by mutual information estimation and maximization)
DIM(Learning deep representations by mutual information estimation and maximization)

論文提出,當地互資訊最大化雖然引入了真實和虛假圖檔的機率,但是并沒有顯著提高效果。

比對表示與先驗分布:

好的表示學習應該是簡潔的、獨立的、無糾纏的(disentangled)或者獨立可控的。如圖所示:訓練解碼器是為了疑惑鑒别器,使之不能分辨出真的圖檔和假的圖檔。真的采樣取自先驗分布,假的采樣取自編碼器。

DIM(Learning deep representations by mutual information estimation and maximization)

公式如下:訓練編碼器最小化散度

DIM(Learning deep representations by mutual information estimation and maximization)

将全局互資訊,區域互資訊和先驗分布比對加到一起,得到如下公式:

DIM(Learning deep representations by mutual information estimation and maximization)

w1 和 w2 分别是鑒别器全局和區域目标的參數,α、β 和 γ 是超參數。

繼續閱讀