天天看點

RNN的梯度消失和梯度爆炸RNN 梯度消失&梯度爆炸

文章目錄

  • RNN 梯度消失&梯度爆炸
    • 1. 深層網絡角度解釋梯度消失和梯度爆炸
    • 2. 激活函數角度解釋梯度消失和梯度爆炸
    • 3. RNN中的梯度消失和CNN的梯度消失有差別
    • 4. 梯度消失、爆炸的解決方案
      • 4.1 梯度爆炸的解決方案
      • 4.2 梯度消失的解決方案
        • 4.2.1 選擇relu、leakrelu、elu等激活函數
        • 4.2.2 使用Batchnorm(batch normalization,簡稱BN):
        • 4.2.3 殘差結構:
        • 4.2.4 LSTM:
    • 5. 參考

RNN 梯度消失&梯度爆炸

參考:https://zhuanlan.zhihu.com/p/33006526?from_voters_page=true

梯度消失和梯度爆炸本質是同一種情況。梯度消失經常出現的原因:一是使用深層網絡;二是采用不合适的損失函數,如Sigmoid。梯度爆炸一般出現的場景:一是深層網絡;二是權值初始化太大。

1. 深層網絡角度解釋梯度消失和梯度爆炸

深層網絡由許多非線性層堆疊而來,每一層網絡激活後的輸出為 f i ( x ) f_{i}(x) fi​(x),其中 i i i為第 i i i層, x x x是第 i i i層的輸入,即第 i − 1 i-1 i−1層的輸出, f f f是激活函數,整個深層網絡可視為一個複合的非線性多元函數:

f i + 1 = f ( f i ∗ w i + b ) F ( x ) = f n ( . . . f 3 ( f 2 ( f 1 ( x ) ∗ w 1 + b ) ∗ w 2 + b ) . . . ) f_{i+1} = f(f_{i}*w_{i}+b) \\ F(x)=f_n(...f_3(f_2(f_1(x)*w_{1}+b)*w_{2}+b)...) fi+1​=f(fi​∗wi​+b)F(x)=fn​(...f3​(f2​(f1​(x)∗w1​+b)∗w2​+b)...)

目的是多元函數 F ( x ) F(x) F(x)完成輸入到輸出的映射,假設不同的輸入,輸出的最優解是g(x),則優化深層網絡就是為了找到合适的權值,滿足 L o s s = L ( g ( x ) , F ( x ) ) Loss=L(g(x),F(x)) Loss=L(g(x),F(x))取得極小值。

BP 算法基于梯度下降政策,以負梯度方向對參數進行調整,參數更新:

w ← w + Δ w Δ w = − α ∂ L o s s ∂ w Δ w 1 = ∂ L o s s ∂ w 2 = ∂ L o s s ∂ f n ∂ f n ∂ f n − 1 ∂ f n − 1 ∂ f n − 2 . . . ∂ f 3 ∂ f 2 ∂ f 2 ∂ w 2 ∂ f 2 ∂ w 2 = f 1 w\leftarrow w+\Delta{w} \\ \Delta{w} = -\alpha\frac{\partial{Loss}}{\partial{w}} \\ \Delta{w_1} = \frac{\partial Loss}{\partial w_2} = \frac{\partial Loss}{\partial f_n}\frac{\partial f_n}{\partial f_{n-1}}\frac{\partial f_{n-1}}{\partial f_{n-2}}... \frac{\partial f_{3}}{\partial f_{2}} \frac{\partial f_{2}}{\partial w_{2}} \\ \frac{\partial f_{2}}{\partial w_{2}}=f_1 w←w+ΔwΔw=−α∂w∂Loss​Δw1​=∂w2​∂Loss​=∂fn​∂Loss​∂fn−1​∂fn​​∂fn−2​∂fn−1​​...∂f2​∂f3​​∂w2​∂f2​​∂w2​∂f2​​=f1​

∂ f n ∂ f n − 1 \frac{\partial f_n}{\partial f_{n-1}} ∂fn−1​∂fn​​即對激活函數求導,如果此部分大于1,随着層數增加,梯度更新将以指數形式增加,即發生梯度爆炸;如果此部分小于1,随着層數增加,梯度更新将以指數形式衰減,即發生梯度消失。

梯度消失、爆炸,其根本原因在于反向傳播訓練法則,鍊式求導次數太多。

2. 激活函數角度解釋梯度消失和梯度爆炸

計算權值更新資訊,需要計算前層偏導資訊,是以激活函數選擇不合适,比如Sigmoid,梯度消失會更明顯。

S i g m o i d ( x ) = 1 1 + e − x S i g m o i d ′ ( x ) = e − x ( 1 + e − x ) 2 = S i g m o i d ( x ) ( 1 − S i g m o i d ( x ) ) Sigmoid(x) = \frac{1}{1+e^{-x}}\\ Sigmoid'(x) = \frac{e^{-x}}{(1+e^{-x})^2} =Sigmoid(x)(1-Sigmoid(x)) Sigmoid(x)=1+e−x1​Sigmoid′(x)=(1+e−x)2e−x​=Sigmoid(x)(1−Sigmoid(x))

如果使用sigmoid作為損失函數,其梯度是不可能超過0.25的,這樣經過鍊式求導之後,很容易發生梯度消失。

RNN的梯度消失和梯度爆炸RNN 梯度消失&梯度爆炸

tanh作為損失函數,它的導數圖如下,可以看出,tanh比sigmoid要好一些,但是它的導數仍然是小于1的。

t a n h ( x ) = s i n h ( x ) c o s h ( x ) = e x − e − x e x + e − x t a n h ′ ( x ) = 1 − ( e x − e − x ) 2 ( e x + e − x ) 2 = 1 − t a n h 2 ( x ) tanh(x) = \frac{sinh(x)}{cosh(x)}=\frac{e^{x}-e^{-x}}{e^{x}+e^{-x}}\\ tanh'(x) = 1-\frac{(e^{x}-e^{-x})^2}{(e^{x}+e^{-x})^2} = 1-tanh^2(x) tanh(x)=cosh(x)sinh(x)​=ex+e−xex−e−x​tanh′(x)=1−(ex+e−x)2(ex−e−x)2​=1−tanh2(x)

RNN的梯度消失和梯度爆炸RNN 梯度消失&梯度爆炸

由于sigmoid和tanh存在上述的缺點,是以relu激活函數成為了大多數神經網絡的預設選擇。relu函數的導數在正數部分是恒等于1,是以在深層網絡中就不存在梯度消失/爆炸的問題,每層網絡都可以得到相同的更新速度。另外計算友善,計算速度快,加速網絡的訓練。

但是relu也存在缺點:即在 x x x小于0時,導數為0,導緻一些神經元無法激活。輸出不是以0為中心的。是以引申出下面的leaky relu函數,但是實際上leaky relu使用的并不多。

R E L U ( x ) = m a x ( 0 , x ) L e a k y R E L U ( x ) = m a x ( 0.01 x , x ) RELU(x)=max(0,x)\\ Leaky RELU(x) =max(0.01x,x) RELU(x)=max(0,x)LeakyRELU(x)=max(0.01x,x)

RNN的梯度消失和梯度爆炸RNN 梯度消失&梯度爆炸
RNN的梯度消失和梯度爆炸RNN 梯度消失&梯度爆炸

3. RNN中的梯度消失和CNN的梯度消失有差別

RNN中的梯度消失/爆炸和MLP/CNN中的梯度消失/爆炸含義不同:MLP/CNN中不同的層有不同的參數,各是各的梯度;而 RNN 中同樣的權重在各個時間步共享,最終的梯度 g 等于各個時間步的梯度 g t g_t gt​ 的和。

  • RNN中的總的梯度不會消失。即便梯度越傳越弱,那也隻是遠距離的梯度消失,由于近距離的梯度不會消失,所有梯度之和并不會消失。RNN 所謂梯度消失的真正含義是,梯度被近距離梯度主導,導緻模型難以學到遠距離的依賴關系。
    RNN的梯度消失和梯度爆炸RNN 梯度消失&梯度爆炸

    RNN前向傳導過程:

    t = 1 s 1 = g ( U x 1 + W s 0 ) o 1 = f ( V g ( U x 1 + W s 0 ) ) t = 2 s 2 = g ( U x 2 + W s 1 ) o 2 = f ( V g ( U x 2 + W s 1 ) ) = f ( V g ( U x 2 + W g ( U x 1 + W s 0 ) ) ) t = 3 s 3 = g ( U x 3 + W s 2 ) o 3 = f ( V g ( U x 3 + W s 2 ) ) = f ( V g ( U x 3 + W g ( U x 2 + W ( U x 1 + W s 0 ) ) ) ) . . . t = m . . . L o s s = L ( o m , y ) ∂ L ∂ U = ∂ L ∂ o m ∂ o m ∂ s m ∂ s m ∂ U + ∂ L ∂ o m ∂ o m ∂ s m ∂ s m ∂ s m − 1 ∂ s m − 1 ∂ U + . . . + ∂ L ∂ o m ∂ o m ∂ s m ∂ s m ∂ s m − 1 . . . ∂ s 2 ∂ s 1 ∂ s 1 ∂ U = ∑ t = 1 m ∂ L ∂ o m ∂ o m ∂ s m ( ∏ j = t + 1 m ∂ s j ∂ s j − 1 ) ∂ s t ∂ U \begin{aligned} t &= 1 \\ s_1 &= g(Ux_1 + Ws_{0})\\ o_1 &= f(Vg(Ux_1 + Ws_{0}))\\ t &= 2 \\ s_2 &= g(Ux_2 + Ws_{1})\\ o_2 &= f(Vg(Ux_2 + Ws_{1})) =f(Vg(Ux_2 + Wg(Ux_1 + Ws_{0})))\\ t &= 3 \\ s_3 &= g(Ux_3 + Ws_{2})\\ o_3 &= f(Vg(Ux_3 + Ws_{2})) = f(Vg(Ux_3 + Wg(Ux_2 + W(Ux_1 + Ws_{0}))))\\ ...\\ t &= m \\ ...\\ Loss &= L(o_m,y)\\ \frac{\partial L}{\partial U} &= \frac{\partial L}{\partial o_m}\frac{\partial o_m}{\partial s_m}\frac{\partial s_m}{\partial U} + \frac{\partial L}{\partial o_m}\frac{\partial o_m}{\partial s_m}\frac{\partial s_m}{\partial s_{m-1}}\frac{\partial s_{m-1}}{\partial U}+...+ \frac{\partial L}{\partial o_m}\frac{\partial o_m}{\partial s_m}\frac{\partial s_m}{\partial s_{m-1}}...\frac{\partial s_{2}}{\partial s_{1}}\frac{\partial s_{1}}{\partial U}\\ &= \sum_{t=1}^{m}\frac{\partial L}{\partial o_m}\frac{\partial o_m}{\partial s_m}\left(\prod_{j=t+1}^{m}\frac{\partial s_{j}}{\partial s_{j-1}}\right)\frac{\partial s_t}{\partial U} \end{aligned} ts1​o1​ts2​o2​ts3​o3​...t...Loss∂U∂L​​=1=g(Ux1​+Ws0​)=f(Vg(Ux1​+Ws0​))=2=g(Ux2​+Ws1​)=f(Vg(Ux2​+Ws1​))=f(Vg(Ux2​+Wg(Ux1​+Ws0​)))=3=g(Ux3​+Ws2​)=f(Vg(Ux3​+Ws2​))=f(Vg(Ux3​+Wg(Ux2​+W(Ux1​+Ws0​))))=m=L(om​,y)=∂om​∂L​∂sm​∂om​​∂U∂sm​​+∂om​∂L​∂sm​∂om​​∂sm−1​∂sm​​∂U∂sm−1​​+...+∂om​∂L​∂sm​∂om​​∂sm−1​∂sm​​...∂s1​∂s2​​∂U∂s1​​=t=1∑m​∂om​∂L​∂sm​∂om​​(j=t+1∏m​∂sj−1​∂sj​​)∂U∂st​​​

    當激活函數為tanh, s t = t a n h ( U x t + W s t − 1 ) s_t = tanh(Ux_t + Ws_{t-1}) st​=tanh(Uxt​+Wst−1​),

    權值梯度:

    L o s s = ∑ t = 1 m ∂ L ∂ o m ∂ o m ∂ s m ( ∏ j = t + 1 m ∂ s j ∂ s j − 1 ) ∂ s t ∂ U = ∑ t = 1 m ∂ L ∂ o m ∂ o m ∂ s m ( ∏ j = t + 1 m t a n h ′ W ) ∂ s t ∂ U \begin{aligned} Loss &= \sum_{t=1}^{m}\frac{\partial L}{\partial o_m}\frac{\partial o_m}{\partial s_m}\left(\prod_{j=t+1}^{m}\frac{\partial s_{j}}{\partial s_{j-1}}\right)\frac{\partial s_t}{\partial U}\\ &= \sum_{t=1}^{m}\frac{\partial L}{\partial o_m}\frac{\partial o_m}{\partial s_m}\left(\prod_{j=t+1}^{m}tanh'W \right)\frac{\partial s_t}{\partial U} \end{aligned} Loss​=t=1∑m​∂om​∂L​∂sm​∂om​​(j=t+1∏m​∂sj−1​∂sj​​)∂U∂st​​=t=1∑m​∂om​∂L​∂sm​∂om​​(j=t+1∏m​tanh′W)∂U∂st​​​

  • MLP/CNN 的梯度消失:主要是随着網絡加深,淺層網絡的梯度越來越小,導緻參數無法更新疊代。

4. 梯度消失、爆炸的解決方案

在深度神經網絡中,往往是梯度消失出現的更多一些。

4.1 梯度爆炸的解決方案

  1. 梯度裁剪:主要思想是設定一個梯度剪切門檻值,然後更新梯度的時候,如果梯度超過這個門檻值,那麼就将其強制限制在這個範圍之内。這可以防止梯度爆炸。
  2. 權值正則化(weithts regularization):正則化是通過對網絡權重做正則限制過拟合,如下正則項在損失函數中的形式:

    L o s s = ( y − W T x ) 2 + α ∣ ∣ W ∣ ∣ 2 Loss = (y-W^Tx)^2+\alpha||W||^2 Loss=(y−WTx)2+α∣∣W∣∣2

    常見的是L1正則和L2正則,在各個深度架構中都有相應的API可以使用正則化。

    其中, α \alpha α是指正則項系數,是以,如果發生梯度爆炸,權值的範數就會變的非常大,通過正則化項,可以部分限制梯度爆炸的發生。

4.2 梯度消失的解決方案

4.2.1 選擇relu、leakrelu、elu等激活函數

  • relu函數的導數在正數部分是恒等于1的,是以在深層網絡中不會導緻梯度消失和爆炸的問題。relu優點:解決了梯度消失、爆炸的問題,計算速度快,加速網絡訓練。relu缺點:導數的負數部分恒為0,會導緻一些神經元無法激活(可通過設定國小習率部分解決),輸出不是以0為中心的。
  • leakrelu就是為了解決relu的0區間帶來的影響。數學表達: l e a k r e l u = m a x ( x ∗ k , x ) leakrelu=max(x*k,x) leakrelu=max(x∗k,x)

    其中k是leak系數,一般選擇0.1或者0.2,或者通過學習而來。leakrelu解決了0區間帶來的影響,而且包含了relu的所有優點。

RNN的梯度消失和梯度爆炸RNN 梯度消失&梯度爆炸
  • elu也是為了解決relu的0區間帶來的影響,其數學表達為:

    e l u ( x ) = { x , if  x > 0 α ( e x − 1 ) , otherwise elu(x) = \begin{cases} x, & \text{if }x>0\\ \alpha(e^x-1), & \text{otherwise} \end{cases} elu(x)={x,α(ex−1),​if x>0otherwise​

    但是elu相對于leakrelu來說,計算要更耗時間一些。

4.2.2 使用Batchnorm(batch normalization,簡稱BN):

目前已經被廣泛的應用到了各大網絡中,具有加速網絡收斂速度,提升訓練穩定性的效果,Batchnorm本質上是解決反向傳播過程中的梯度問題。通過規範化操作将輸出信号x規範化到均值為0,方差為1保證網絡的穩定性。 具體來說就是反向傳播中,經過每一層的梯度會乘以該層的權重,舉個簡單例子: 正向傳播中 f 3 = f 2 ( w T x + b ) f_3=f_2(w^Tx+b) f3​=f2​(wTx+b) ,那麼反向傳播中, ∂ f 2 ∂ x = ∂ f 2 ∂ f 1 w \frac{\partial f_2}{\partial x}=\frac{\partial f_2}{\partial f_1}w ∂x∂f2​​=∂f1​∂f2​​w, 反向傳播式子中有 w w w 的存在,是以 w w w 的大小影響了梯度的消失和爆炸,batchnorm就是通過對每一層的輸出做scale和shift的方法,通過一定的規範化手段,把每層神經網絡任意神經元這個輸入值的分布強行拉回到接近均值為0方差為1的标準正太分布,即嚴重偏離的分布強制拉回比較标準的分布,這樣使得激活輸入值落在非線性函數對輸入比較敏感的區域,這樣輸入的小變化就會導緻損失函數較大的變化,使得讓梯度變大,避免梯度消失問題産生,而且梯度變大意味着學習收斂速度快,能大大加快訓練速度。

4.2.3 殘差結構:

殘差單元裡的shortcut(捷徑)部分可以保證在反向傳播中梯度不會消失。

∂ L o s s ∂ x l = ∂ L o s s ∂ x L ∂ x L ∂ x l = ∂ L o s s ∂ x L ( 1 + ∂ ∂ x L ∑ i = l L − 1 F ( x i , W i ) ) \frac{\partial Loss}{\partial x_l}=\frac{\partial Loss}{\partial x_L}\frac{\partial x_L}{\partial x_l}=\frac{\partial Loss}{\partial x_L}\left(1+\frac{\partial }{\partial x_L}\sum_{i=l}^{L-1}F(x_i,W_i)\right) ∂xl​∂Loss​=∂xL​∂Loss​∂xl​∂xL​​=∂xL​∂Loss​(1+∂xL​∂​i=l∑L−1​F(xi​,Wi​))

式子的第一個因子 ∂ L o s s ∂ x L \frac{\partial Loss}{\partial x_L} ∂xL​∂Loss​表示的損失函數到達 L 層的梯度,小括号中的1表明短路機制可以無損地傳播梯度,而另外一項殘差梯度則需要經過帶有weights的層,梯度不是直接傳遞過來的。殘差梯度不會那麼巧全為-1,而且就算其比較小,有1的存在也不會導緻梯度消失。是以殘差學習會更容易。

RNN的梯度消失和梯度爆炸RNN 梯度消失&梯度爆炸

4.2.4 LSTM:

使用LSTM(long-short term memory networks,長短期記憶網絡),就不那麼容易發生梯度消失,主要原因在于LSTM内部複雜的“門”(gates),如下圖,LSTM通過它内部的“門”可以在更新的時候“記住”前幾次訓練的“殘留記憶”,是以,經常用于生成文本中。

RNN的梯度消失和梯度爆炸RNN 梯度消失&梯度爆炸

5. 參考

RNN

  • https://zybuluo.com/hanbingtao/note/541458
  • https://colab.research.google.com/drive/1Zfvt9Vfs3PrJwSDF8jMvomz7CzU36RXk

LSTM

  • http://colah.github.io/posts/2015-08-Understanding-LSTMs/
  • https://www.youtube.com/watch?v=9zhrxE5PQgY&feature=youtu.be
  • https://towardsdatascience.com/illustrated-guide-to-lstms-and-gru-s-a-stepby-step-explanation-44e9eb85bf21
  • https://medium.com/datadriveninvestor/how-do-lstm-networks-solve-theproblem-of-vanishing-gradients-a6784971a577

繼續閱讀