天天看點

批标準化(BatchNorm)

注:本文部分參考自以下文章:

深入了解Batch Normalization批标準化

李理:卷及神經網絡之Batch Normalization的原理及實作

原文連結:《Batch Normalizaion: Accelerating Deep Network Training by Reducing Internal Convariate Shift》

翻譯、導讀等推薦:1、2

1. BN目的
機器學習領域有個很重要的假設:獨立同分布(IID,Independent Identically Distributed)假設,就是假設訓練資料和測試資料是滿足相同分布的,這是通過訓練資料獲得的模型能夠在測試集獲得好的效果的一個基本保障。那BatchNorm的作用是什麼呢?BatchNorm就是在深度神經網絡訓練過程中使得每一層神經網絡的輸入保持相同分布的。
2. 内部協變量漂移(Internal Covariate Shift)

When the input distribution to a learning system changes, it is said to experience covariate shift

covariate shift問題是由于訓練資料的領域模型 Ps(X) 和測試資料的 Pt(X) 分布不一緻造成的,這裡的下标s和t是source和target的縮寫,代表訓練和測試。

Mini-Batch SGD vs SGD(one sample):梯度更新方向準确、并行計算速度快,但需要調節很多超參數(學習率、初值等)。

各層權重參數嚴重影響每層的輸入,輸入的小變動随着層數加深不斷放大。這就導緻,各層輸入分布的變動導緻模型需要不停地去拟合新的分布。

于是,BN希望通過每層的輸入均值、方差進行規範化,使輸入分布一緻

3. BN的思想
對于每個隐層神經元,把逐漸向非線性函數映射後向取值區間極限飽和區靠攏的輸入分布強制拉回到均值為0方差為1的比較标準的正态分布,使得非線性變換函數的輸入值落入對輸入比較敏感的區域,以此避免梯度消失問題。

如果我們能保證每次minibatch時每個層的輸入資料都是均值0方差1,那麼就可以解決這個問題。是以我們可以加一個batch normalization層對這個minibatch的資料進行處理。但是這樣也帶來一個問題,把某個層的輸出限制在均值為0方差為1的分布會使得網絡的表達能力變弱。是以作者又給batch normalization層進行一些限制的放松,給它增加兩個可學習的參數 β 和 γ ,對資料進行縮放和平移,平移參數 β 和縮放參數 γ 是學習出來的。

備注:

由于sigmoid這類激活函數,隻有在0左右的鄰域處,導數較大;是以,BN政策在一定程度上可以保證激活函數的梯度一直較大,這避免了梯度消失問題;并且梯度夠大表明訓練速度較快。當然,由于這使得x大多落在sigmoid的線性區,而違背了當初使用sigmoid非線性變換的初衷,進而降低了表達能力,上述的參數β和γ在一定程度上可以解決此問題。

批标準化(BatchNorm)
4. BN 的預測(Inference)

雖然訓練過程可以根據 Mini-Batch來獲得統計量,但是預測時隻有一個資料,無從計算合理的均值和方差。解決辦法是,使用訓練的所有資料(population)的均值和方差(用每個mini-batch的統計量計算得來即可)。

有了均值和方差,每個隐含層也有訓練好的β和γ,就可以在預測過程中進行BN操作了

5. BN 的優勢

論文中将Batch Normalization的作用說得突破天際,好似一下解決了所有問題,下面就來一一列舉一下:

  

(1) 可以使用更高的學習率。如果每層的scale不一緻,實際上每層需要的學習率是不一樣的,同一層不同次元的scale往往也需要不同大小的學習率,通常需要使用最小的那個學習率才能保證損失函數有效下降,Batch Normalization将每層、每維的scale保持一緻,那麼我們就可以直接使用較高的學習率進行優化。

(2) 移除或使用較低的dropout。 dropout是常用的防止overfitting的方法,而導緻overfit的位置往往在資料邊界處,如果初始化權重就已經落在資料内部,overfit現象就可以得到一定的緩解。論文中最後的模型分别使用10%、5%和0%的dropout訓練模型,與之前的40%-50%相比,可以大大提高訓練速度。

(3) 降低L2權重衰減系數。 還是一樣的問題,邊界處的局部最優往往有幾維的權重(斜率)較大,使用L2衰減可以緩解這一問題,現在用了Batch Normalization,就可以把這個值降低了,論文中降低為原來的5倍。

(4) 取消Local Response Normalization層。 由于使用了一種Normalization,再使用LRN就顯得沒那麼必要了。而且LRN實際上也沒那麼work。

(5) 減少圖像扭曲的使用。 由于現在訓練epoch數降低,是以要對輸入資料少做一些扭曲,讓神經網絡多看看真實的資料。

推薦閱讀:深入解讀Inception V2之Batch Normalization