天天看點

深度學習通用政策:BN原理詳解以及優勢

今年過年之前,MSRA和Google相繼在ImagenNet圖像識别資料集上報告他們的效果超越了人類水準,下面将分兩期介紹兩者的算法細節。

  這次先講Google的這篇《Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift》,主要是因為這裡面的思想比較有普适性,而且一直答應群裡的人寫一個有關預處理的科普,但一直沒抽出時間來寫。

一、神經網絡中的權重初始化與預處理方法的關系

如果做過dnn的實驗,大家可能會發現在對資料進行預處理,例如白化或者zscore,甚至是簡單的減均值操作都是可以加速收斂的,例如下圖所示的一個簡單的例子:

深度學習通用政策:BN原理詳解以及優勢

  圖中紅點代表2維的資料點,由于圖像資料的每一維一般都是0-255之間的數字,是以資料點隻會落在第一象限,而且圖像資料具有很強的相關性,比如第一個灰階值為30,比較黑,那它旁邊的一個像素值一般不會超過100,否則給人的感覺就像噪聲一樣。由于強相關性,資料點僅會落在第一象限的很小的區域中,形成類似上圖所示的狹長分布。

  而神經網絡模型在初始化的時候,權重W是随機采樣生成的,一個常見的神經元表示為:ReLU(Wx+b) = max(Wx+b,0),即在Wx+b=0的兩側,對資料采用不同的操作方法。具體到ReLU就是一側收縮,一側保持不變。

  随機的Wx+b=0表現為上圖中的随機虛線,注意到,兩條綠色虛線實際上并沒有什麼意義,在使用梯度下降時,可能需要很多次疊代才會使這些虛線對資料點進行有效的分割,就像紫色虛線那樣,這勢必會帶來求解速率變慢的問題。更何況,我們這隻是個二維的示範,資料占據四個象限中的一個,如果是幾百、幾千、上萬維呢?而且資料在第一象限中也隻是占了很小的一部分區域而已,可想而知不對資料進行預處理帶來了多少運算資源的浪費,而且大量的資料外分割面在疊代時很可能會在剛進入資料中時就遇到了一個局部最優,導緻overfit的問題。

  這時,如果我們将資料減去其均值,資料點就不再隻分布在第一象限,這時一個随機分界面落入資料分布的機率增加了多少呢?2^n倍!如果我們使用去除相關性的算法,例如PCA和ZCA白化,資料不再是一個狹長的分布,随機分界面有效的機率就又大大增加了。

  不過計算協方差矩陣的特征值太耗時也太耗空間,我們一般最多隻用到z-score處理,即每一次元減去自身均值,再除以自身标準差,這樣能使資料點在每維上具有相似的寬度,可以起到一定的增大資料分布範圍,進而使更多随機分界面有意義的作用。

二、Batch Normalization

  上一節我們講到對輸入資料進行預處理,減均值->zscore->白化可以逐級提升随機初始化的權重對資料分割的有效性,還可以降低overfit的可能性。我們都知道,現在的神經網絡的層數都是很深的,如果我們對每一層的資料都進行處理,訓練時間和overfit程度是否可以降低呢?Google的這篇論文給出了答案。

1、算法描述

  按照第一章的理論,應當在每一層的激活函數之後,例如ReLU=max(Wx+b,0)之後,對資料進行歸一化。然而,文章中說這樣做在訓練初期,分界面還在劇烈變化時,計算出的參數不穩定,是以退而求其次,在Wx+b之後進行歸一化。因為初始的W是從标準高斯分布中采樣得到的,而W中元素的數量遠大于x,Wx+b每維的均值本身就接近0、方差接近1,是以在Wx+b後使用Batch Normalization能得到更穩定的結果。

       文中使用了類似z-score的歸一化方式:每一次元減去自身均值,再除以自身标準差,由于使用的是随機梯度下降法,這些均值和方差也隻能在目前疊代的batch中計算,故作者給這個算法命名為Batch Normalization。這裡有一點需要注意,像卷積層這樣具有權值共享的層,Wx+b的均值和方差是對整張map求得的,在batch_size * channel * height * width這麼大的一層中,對總共batch_size*height*width個像素點統計得到一個均值和一個标準差,共得到channel組參數。

  在Normalization完成後,Google的研究員仍對數值穩定性不放心,又加入了兩個參數gamma和beta,使得

深度學習通用政策:BN原理詳解以及優勢

       注意到,如果我們令gamma等于之前求得的标準差,beta等于之前求得的均值,則這個變換就又将資料還原回去了。在他們的模型中,這兩個參數與每層的W和b一樣,是需要疊代求解的。文章中舉了個例子,在sigmoid激活函數的中間部分,函數近似于一個線性函數(如下圖所示),使用BN後會使歸一化後的資料僅使用這一段線性的部分(吐槽一下:再乘個2之類的不就行了)。

深度學習通用政策:BN原理詳解以及優勢

       可以看到,在[0.2, 0.8]範圍内,sigmoid函數基本呈線性遞增,甚至在[0.1, 0.9]範圍内,sigmoid函數都是類似于線性函數的,如果隻用這一段,那網絡不就成了線性網絡了麼,這顯然不是大家願意見到的。至于這兩個參數對ReLU起的作用文中沒說,我就不妄自揣摩了哈。

       算法原理到這差不多就講完了,下面是大家 最不喜歡的公式環節了,求均值和方差就不用說了,在BP的時候,我們需要求最終的損失函數對gamma和beta兩個參數的導數,還要求損失函數對Wx+b中的x的導數,以便使誤差繼續向後傳播。求導公式如下:

深度學習通用政策:BN原理詳解以及優勢

  具體的公式推導就不寫了,有興趣的讀者可以自己推一下,主要用到了鍊式法則。

  在訓練的最後一個epoch時,要對這一epoch所有的訓練樣本的均值和标準差進行統計,這樣在一張測試圖檔進來時,使用訓練樣本中的标準差的期望和均值的期望(好繞口)對測試資料進行歸一化,注意這裡标準差使用的期望是其無偏估計:

深度學習通用政策:BN原理詳解以及優勢

2、算法優勢

  論文中将Batch Normalization的作用說得突破天際,好似一下解決了所有問題,下面就來一一列舉一下:    (1) 可以使用更高的學習率 。如果每層的scale不一緻,實際上每層需要的學習率是不一樣的,同一層不同次元的scale往往也需要不同大小的學習率,通常需要使用最小的那個學習率才能保證損失函數有效下降,Batch Normalization将每層、每維的scale保持一緻,那麼我們就可以直接使用較高的學習率進行優化。    (2) 移除或使用較低的dropout。 dropout是常用的防止overfitting的方法,而導緻overfit的位置往往在資料邊界處,如果初始化權重就已經落在資料内部,overfit現象就可以得到一定的緩解。論文中最後的模型分别使用10%、5%和0%的dropout訓練模型,與之前的40%-50%相比,可以大大提高訓練速度。    (3) 降低L2權重衰減系數 。 還是一樣的問題,邊界處的局部最優往往有幾維的權重(斜率)較大,使用L2衰減可以緩解這一問題,現在用了Batch Normalization,就可以把這個值降低了,論文中降低為原來的5倍。    (4) 取消Local Response Normalization層 。 由于使用了一種Normalization,再使用LRN就顯得沒那麼必要了。而且LRN實際上也沒那麼work。    (5) 減少圖像扭曲的使用。 由于現在訓練epoch數降低,是以要對輸入資料少做一些扭曲,讓神經網絡多看看真實的資料。

三、實驗

  這裡我隻在matlab上面對算法進行了仿真,修改了DeepLearnToolbox 裡面的NN模型,代碼如下:

  在前向傳播時,分兩種情況進行讨論:如果是在train過程,就使用目前batch的資料統計均值和标準差,并按照第二章所述公式對Wx+b進行歸一化,之後再乘上gamma,加上beta得到Batch Normalization層的輸出;如果在進行test過程,則使用記錄下的均值和标準差,還有之前訓練好的gamma和beta計算得到結果

[plain]  view plain  copy

  1. if nn.testing  
  2.     nn.a_pre{i} = nn.a{i - 1} * nn.W{i - 1}';  
  3.     norm_factor = nn.gamma{i-1}./sqrt(nn.mean_sigma2{i-1}+nn.epsilon);  
  4.     nn.a_hat{i} = bsxfun(@times, nn.a_pre{i}, norm_factor);  
  5.     nn.a_hat{i} = bsxfun(@plus, nn.a_hat{i}, nn.beta{i-1} -  norm_factor .* nn.mean_mu{i-1});  
  6. else  
  7.     nn.a_pre{i} = nn.a{i - 1} * nn.W{i - 1}';  
  8.     nn.mu{i-1} = mean(nn.a_pre{i});  
  9.     x_mu = bsxfun(@minus,nn.a_pre{i},nn.mu{i-1});  
  10.     nn.sigma2{i-1} = mean(x_mu.^2);  
  11.     norm_factor = nn.gamma{i-1}./sqrt(nn.sigma2{i-1}+nn.epsilon);  
  12.     nn.a_hat{i} = bsxfun(@times, nn.a_pre{i}, norm_factor);  
  13.     nn.a_hat{i} = bsxfun(@plus, nn.a_hat{i}, nn.beta{i-1} -  norm_factor .* nn.mu{i-1});  
  14. end;  

  反向傳播就跟上面那一堆公式一樣啦,注意為了運作效率,盡量使用向量化的代碼,避免使用for循環: [plain]     view plain  copy

  1. d_xhat = bsxfun(@times, d{i}(:,2:end), nn.gamma{i-1});  
  2. x_mu = bsxfun(@minus, nn.a_pre{i}, nn.mu{i-1});  
  3. inv_sqrt_sigma = 1 ./ sqrt(nn.sigma2{i-1} + nn.epsilon);  
  4. d_sigma2 = -0.5 * sum(d_xhat .* x_mu) .* inv_sqrt_sigma.^3;  
  5. d_mu = bsxfun(@times, d_xhat, inv_sqrt_sigma);  
  6. d_mu = -1 * sum(d_mu) -2 .* d_sigma2 .* mean(x_mu);  
  7. d_gamma = mean(d{i}(:,2:end) .* nn.a_hat{i});  
  8. d_beta = mean(d{i}(:,2:end));  
  9. di1 = bsxfun(@times,d_xhat,inv_sqrt_sigma);  
  10. di2 = 2/m * bsxfun(@times, d_sigma2,x_mu);  
  11. d{i}(:,2:end) = di1 + di2 + 1/m * repmat(d_mu,m,1);  

  在訓練的最後一個epoch,要對所有的gamma和beta進行統計,代碼很簡單就不貼了,完整代碼在我的Github上有:https://github.com/happynear/DeepLearnToolbox

1、sigmoid激活函數的過飽和問題

  經測試發現算法對sigmoid激活函數的提升非常明顯,解決了困擾學術界十幾年的sigmoid過飽和的問題,即在深層的神經網絡中,前幾層在梯度下降時得到的梯度過低,導緻深層神經網絡變成了前邊是随機變換,隻在最後幾層才是真正在做分類的問題。

  下面是使用一個10個隐藏層的nn網絡,對mnist進行分類,每層的梯度值:

  使用Batch Normalization前:

[plain]  view plain  copy

  1. epoch:1 iteration:10/300  
  2.  3.23e-07 8.3215e-07 3.3605e-06 1.5193e-05 6.4892e-05 0.00027249 0.0011954 0.006295 0.029835 0.12476 0.38948  
  3. epoch:1 iteration:20/300  
  4.  4.4649e-07 1.3282e-06 5.6753e-06 2.5294e-05 0.00010326 0.00043651 0.0019583 0.0096396 0.040469 0.16142 0.5235  
  5. epoch:1 iteration:30/300  
  6.  4.6973e-07 1.2993e-06 5.3923e-06 2.3111e-05 9.4839e-05 0.00040398 0.0017893 0.0081367 0.037543 0.1544 0.46472  
  7. epoch:1 iteration:40/300  
  8.  4.6986e-07 1.3801e-06 5.677e-06 2.4355e-05 0.00010245 0.00041999 0.0019832 0.0095022 0.043719 0.17696 0.56134  
  9. epoch:1 iteration:50/300  
  10.  4.6964e-07 1.6532e-06 7.2543e-06 3.0731e-05 0.00011805 0.00048795 0.0021705 0.0099466 0.042835 0.17993 0.5319  

  可以看到,最開始的幾層隻有1e-6到1e-7這個量級的梯度,基本上梯度在最後3層就已經飽和了。

  使用Batch Normalization後:

[plain]  view plain  copy

  1. epoch:1 iteration:10/300  
  2.  0.27121 0.15534 0.15116 0.15409 0.15515 0.14542 0.12878 0.13888 0.16607 0.21036 0.76037  
  3. epoch:1 iteration:20/300  
  4.  0.24567 0.15369 0.14169 0.13183 0.1278 0.13904 0.13546 0.12032 0.14332 0.14868 0.54481  
  5. epoch:1 iteration:30/300  
  6.  0.30403 0.16365 0.14119 0.14502 0.13916 0.12851 0.11781 0.11424 0.11082 0.1088 0.39574  
  7. epoch:1 iteration:40/300  
  8.  0.32681 0.19801 0.16792 0.14741 0.13294 0.12805 0.13754 0.12941 0.13288 0.12957 0.50937  
  9. epoch:1 iteration:50/300  
  10.  0.32358 0.17484 0.16367 0.16605 0.17118 0.14703 0.14458 0.12693 0.13928 0.11938 0.3692  

  我第一次看到的時候,就像之前看到ReLU一樣驚豔,終于,sigmoid的飽和問題也得到了解決。不過論文中還有我自己的實驗都表明,sigmoid在分類問題上确實沒有ReLU好用,大概是因為sigmoid的中間部分太“線性”了,不像ReLU一個很大的轉折,在拟合複雜非線性函數的時候可能沒那麼高效,真的是蠻遺憾的。

2、gamma和beta的作用

  在第二章提到,引入gamma和beta兩個參數是為了避免資料隻用sigmoid的線性部分,這裡做了個簡單的測試,将用和不用gamma與beta參數訓練出的網絡的最大/最小激活值顯示出來:

深度學習通用政策:BN原理詳解以及優勢

  可以看到,如果不使用gamma和beta,激活值基本上會在[0.1 0.9]這個近似線性的區域中,這與深度神經網絡所要求的“多層非線性函數逼近任意函數”的要求不符,是以引入gamma和beta還是有必要的,深度網絡會自動決定使用哪一段函數(這是我自己想的,其具體作用歡迎讨論)。

  對于ReLU來說,gamma的作用可能不是很明顯,因為relu是分段”線性“的,對數值進行伸縮并不能影響relu取x還是取0。但beta的作用就很大了,試想一下如果沒有beta,經過batch normalization層的特征,都具有0均值的期望,這樣豈不是強制令ReLU的輸出有一半是0一半非0麼?這與我們的初衷不太相符,我們希望神經網絡自行決定在什麼位置去設定這個門檻值,而不是增加一個如此強的限制。另外,因為這個beta我曾經還鬧了個大笑話,記錄在http://blog.csdn.net/happynear/article/details/46583811,請大家引以為戒。

四、總結

  Batch Normalization的加速作用展現在兩個方面:一是歸一化了每層和每次元的scale,是以可以整體使用一個較高的學習率,而不必像以前那樣遷就小scale的次元;二是歸一化後使得更多的權重分界面落在了資料中,降低了overfit的可能性,是以一些防止overfit但會降低速度的方法,例如dropout和權重衰減就可以不使用或者降低其權重。   截止到目前,還沒有哪個機構宣布重制了論文中的結果,不過歸一化的用處在理論層面就已經有了保證,以後也許歸一化的形式會有所改變,但逐層的歸一化應該會成為一種标準。本部落格文章僅僅給出了歸一化優點的幾何解釋,希望有更多的理論解釋來指導我們使用歸一化層。   就目前來看,争議的重點在于歸一化的位置,還有gamma與beta參數的引入,從理論上分析,論文中的這兩個細節實際上并不符合ReLU的特性:ReLU後,資料分布重新回到第一象限,這時是最應當進行歸一化的;gamma與beta對sigmoid函數确實能起到一定的作用(實際也不如固定gamma=2),但對于ReLU這種分段線性的激活函數,并不存在sigmoid的低scale呈線性的現象。期待更多的理論分析,我自己也會持續跟進這個方向。

五、一些資源

本文所用到的matlab代碼:https://github.com/happynear/DeepLearnToolbox Caffe的BN實作:https://github.com/ducha-aiki/caffe/tree/bn cxxnet的BN實作:https://github.com/antinucleon/cxxnet

繼續閱讀