天天看點

Batch Normalization的一些個人了解

簡單說一說Batch Normalization的一些個人了解:

1、要說batch normalization不得不先說一下梯度消失和梯度爆炸問題

梯度消失一是容易出現在深層網絡中,二是采用了不合适的損失函數也較容易發生梯度消失,比如sigmoid;

梯度爆炸一般出現在深層網絡和權值初始化值太大的情況。

考慮這樣一個簡單的三隐層的全連接配接網絡

Batch Normalization的一些個人了解

我們假設每一層網絡激活後的輸出為fi(x),其中i表示第i層, x代表第i層的輸入,也就是第i−1層的輸出,f是激活函數,那麼得出fi+1=f(fi∗wi+1+bi+1) ,偏置在梯度消失和梯度爆炸中影響較小,是以我們可以簡單記為fi+1=f(fi∗wi+1)。

如果要更新第二隐層的權值資訊,根據鍊式求導法則,更新梯度資訊:

Batch Normalization的一些個人了解

很容易看出來∂f2 /∂w2即為第二隐藏層的輸入,∂f4 / ∂f3就是對激活函數求導。如果此部分大于1,那麼層數增多的時候,最終的求出的梯度更新資訊将以指數形式增加,即發生梯度爆炸,如果此部分小于1,那麼随着層數增多,求出的梯度更新資訊将會以指數形式衰減,即發生了梯度消失(注:在深層網絡中梯度消失發生的常見一點,而梯度爆炸比較少見)。

如下圖所示,是sigmoid函數的導數資訊,可見sigmoid函數的導數最大為0.25,是以使用sigmoid函數,網絡過深時較容易發生梯度消失

Batch Normalization的一些個人了解

梯度消失和梯度爆炸問題也說明一個問題:随着網絡深度的增加,網絡的輸出在輸入到激活函數中之前,往往會趨于分布在較大的值附近或較小的值附近,而BN正是可以較好的解決這個問題

2、介紹batch normalization

因為深層神經網絡在做非線性變換前的激活輸入值(就是x=WU+B,U是輸入,x是激活函數的輸入)随着訓練過程加深中,其分布逐漸發生偏移或者變動,之是以訓練收斂慢一般是整體分布逐漸往非線性函數的取值區間的上下限兩端靠近(對于Sigmoid函數來說,意味着激活輸入值WU+B是大的負值或正值),是以這導緻激活單元飽和,即反向傳播時低層神經網絡的梯度消失,這是訓練深層神經網絡收斂越來越慢的本質原因。

而BN就是通過一定的規範化手段,把每層神經網絡任意神經元的輸入值的分布強行拉回到均值為0方差為1的标準正态分布,其實就是把越來越偏的分布強制拉回比較标準的分布,這樣使得激活函數輸入值落在非線性函數對輸入比較敏感的區域,這樣輸入的小變化就會導緻激活函數較大的變化,也就是讓梯度變大,避免産生梯度消失問題,而且梯度變大也意味着學習收斂速度快,能大大加快訓練速度。

繼續閱讀