天天看點

蒸餾神經網絡(Distill the Knowledge in a Neural Network)

本文是閱讀Hinton 大神在2014年NIPS上一篇論文:蒸餾神經網絡的筆記,特此說明。此文讀起來很抽象,大篇的論述,鮮有公式和圖表。但是鑒于和我的研究方向:神經網絡的壓縮十分相關,是以決定花氣力好好了解一下。 

1、Introduction  

文章開篇用一個比喻來引入網絡蒸餾:

    昆蟲作為幼蟲時擅于從環境中汲取能量,但是成長為成蟲後确是擅于其他方面,比如遷徙和繁殖等。

同理神經網絡訓練階段從大量資料中擷取網絡模型,訓練階段可以利用大量的計算資源且不需要實時響應。然而到達使用階段,神經網絡需要面臨更加嚴格的要求包括計算資源限制,計算速度要求等等。

由昆蟲的例子我們可以這樣了解神經網絡:一個複雜的網絡結構模型是若幹個單獨模型組成的集合,或者是一些很強的限制條件下(比如dropout率很高)訓練得到的一個很大的網絡模型。一旦複雜網絡模型訓練完成,我們便可以用另一種訓練方法:“蒸餾”,把我們需要配置在應用端的縮小模型從複雜模型中提取出來。

      “蒸餾”的難點在于如何縮減網絡結構但是把網絡中的知識保留下來。知識就是一幅将輸入向量導引至輸出向量的地圖。做複雜網絡的訓練時,目标是将正确答案的機率最大化,但這引入了一個副作用:這種網絡為所有錯誤答案配置設定了機率,即使這些機率非常小。 

      我們将複雜模型轉化為小模型時需要注意保留模型的泛化能力,一種方法是利用由複雜模型産生的分類機率作為“軟目标”來訓練小模型。在轉化階段,我們可以用同樣的訓練集或者是另外的“轉化”訓練集。當複雜模型是由簡單模型複合而成時,我們可以用各自的機率分布的代數或者幾何平均數作為“軟目标”。當“軟目标的”熵值較高時,相對“硬目标”,它每次訓練可以提供更多的資訊和更小的梯度方差,是以小模型可以用更少的資料和更高的學習率進行訓練。 

    像MNIST這種任務,複雜模型可以給出很完美的結果,大部分資訊分布在小機率的軟目标中。比如一張2的圖檔被認為是3的機率為0.000001,被認為是7的機率是0.000000001。Caruana用logits(softmax層的輸入)而不是softmax層的輸出作為“軟目标”。他們目标是是的複雜模型和小模型分别得到的logits的平方差最小。而我們的“蒸餾法”:第一步,提升softmax表達式中的調節參數T,使得複雜模型産生一個合适的“軟目标”  第二步,采用同樣的T來訓練小模型,使得它産生相比對的“軟目标”

   “轉化”訓練集可以由未打标簽的資料組成,也可以用原訓練集。我們發現使用原訓練集效果很好,特别是我們在目标函數中加了一項之後。這一項的目的是是的小模型在預測實際目标的同時盡量比對“軟目标”。要注意的是,小模型并不能完全無誤的比對“軟目标”,而正确結果的犯錯方向是有幫助的。

2、Distillation 

蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)

    softmax層的公式如下: 

蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)

      T就是調節參數,一般設為1。T越大,分類的機率分布越“軟” 

   “蒸餾”最簡單的形式就是:以從複雜模型得到的“軟目标”為目标(這時T比較大),用“轉化”訓練集訓練小模型。訓練小模型時T不變仍然較大,訓練完之後T改為1。 

   當“轉化”訓練集中部分或者所有資料都有标簽時,這種方式可以通過一起訓練模型使得模型得到正确的标簽來大大提升效果。一種實作方法是用正确标簽來修正“軟目标”,但是我們發現一種更好的方法是:對兩個目标函數設定權重系數。第一個目标函數是“軟目标”的交叉熵,這個交叉熵用開始的那個比較大的T來計算。第二個目标函數是正确标簽的交叉熵,這個交叉熵用小模型softmax層的logits來計算且T等于1。我們發現當第二個目标函數權重較低時可以得到最好的結果 

蒸餾神經網絡(Distill the Knowledge in a Neural Network)

3、Preliminary experiments on MNIST 

蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)

  我的了解:将遷移資料集中的3或者7、8去掉是為了證明小模型也能夠從soft target中學得知識。 

4、Experiments on Speech Recognition 

蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)

5、Training ensembles of specialists on very big datasets 

蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)
蒸餾神經網絡(Distill the Knowledge in a Neural Network)

繼續閱讀