天天看點

你所不知道的潛藏在Batch Normalization 中的細節

作者:機器學習搬運工

Batch Normalization 在深度學習網絡設計過程中是一個較為常用技巧,如著名的ResNet 算法,Batch Normalization 貫穿在始終,具體如下:

ResNet 基本結構,其中

你所不知道的潛藏在Batch Normalization 中的細節

BasicBlock應用與ResNet-18、ResNet-34模型,BottleNeck應用與ResNet-50、ResNet-101、ResNet-152模型

Batch Normalization 之是以有效,分析原因主要有如下兩點:

1.通過對隐藏層各神經元的輸入做類似的标準化處理,提高神經網絡訓練速度

2.可以使前面層的權重變化對後面層造成的影響減小,整體網絡更加健壯

關于第一點,對資料批量縮放之後,能加快網絡收斂速度;關于第二點,如果實際應用樣本和訓練樣本的資料分布不同我們稱發生了「Covariate Shift」。這種情況下,一般要對模型進行重新訓練。Batch Normalization 的作用就是減小 Covariate Shift 所帶來的影響,讓模型變得更加健壯,魯棒性(Robustness)更強。

由于每個 Mini-Batch 而非整個資料集上計算均值和方差,隻由這一小部分資料估計得出的均值和方差會有一些噪聲,類似于 Dropout,這種噪聲會使得神經元不會再特别依賴于任何一個輸入特征;減小過拟合。吳恩達老師也提醒大家,不要将 Batch Normalization 作為正則化的手段,而是當作加速學習的方式。正則化隻是一種非期望的副作用。

下面給出Batch Normalization 計算方式:

你所不知道的潛藏在Batch Normalization 中的細節

Bach Normalization 算法

公式中λ,β 是待學習參數;

上面部分可能是絕大多數人的認知,Batch Normalization 潛藏在其中的細節問題是在測試(推理)時候,均值和方差如何選取?

《Batch Normalization: Accelerating Deep Network Training b y Reducing Internal Covariate Shift》 paper 給出的計算方式如下:

你所不知道的潛藏在Batch Normalization 中的細節

其中的均值是每個batch的均值的均值,方差是每個batch的無偏估計量。但是在pytorch具體實作是采用以上所說的滑動平均值方法計算的,pytorch 中滑動均值和方差采取了權重平均的方式,計算如下:

你所不知道的潛藏在Batch Normalization 中的細節

設定α權重,上一個batch 的均值μ,δ 與目前均值進行權重平均,得到新的μ,δ

是以最後一旦整個訓練階段完成,BN層中的所有參數也就固定下來,然後直接用于測試推理.

總結:本文介紹了Batch Normalization 的在深度學習神經網絡設計中的作用,從論文中給出了相關計算方式,但是很多人忽略的細節問題是推理過程中均值和方差是如何确定的,論文中和實際架構中的方式也有一定的差别,本文重點闡述了這個問題,希望得到大家的重視。

繼續閱讀