天天看點

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

最近面試被問到了LSTM為什麼能夠解決long-range dependency的問題,回答這個問題實際上需要把BPTT公式寫出來,在這篇博文中我們進行了部分推導

習翔宇:RNN Part 3-RNN中的BPTT算法和梯度消失問題​zhuanlan.zhihu.com

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

但是不夠系統化,本篇博文将完全對RNN的BPTT以及LSTM的BPTT進行推導,并對long-range dependency問題進行分析

1. RNN的BPTT

假設RNN的基本方程如下所示

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析
BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

損失函數定義如下:

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

對于一個輸入序列

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

,其整體損失函數為

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

我們接下來分别對

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

進行求導

首先對

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

進行求導,這個比較簡單

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

然後對

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

進行求導

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

如下公式可知

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析
BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

的計算涉及到

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

,而

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

的計算也涉及到

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

,同樣

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

的計算涉及到

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

,而

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

的計算也涉及到

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

,以此類推,是以需要回溯到t時刻之前的所有時刻,我們需要對公示(6)中的第三項

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

進行展開,下面我們單獨對其進行展開如下所示:

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

同樣的道理,公示(8)中的第一項

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

的計算如下所示

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

将其帶入到公示(8)中即可得到

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

這樣我們把公式(6)中的第三項就展開了,現在帶入公式(6)中即可得到:

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

按照同樣的方式,我們對

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

進行求導

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

2. RNN梯度消失分析

在上面的推導中,我們對

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

部分的推導公式(11),(12)可以看到,在計算

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

時刻的損失産生的梯度時,必須回溯之前所有時刻

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

的資訊,并且存在連乘項

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

,根據公式(1)我們可以計算

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

sigmoid函數的導數大家都很熟悉了,處于

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

之間,那麼會有以下兩種情況:

  1. BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析
    > 4的時候,那麼
    BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析
    ,此時如果
    BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析
    距離過大,會導緻連乘項過多,産生梯度爆炸,趨近于無窮
  2. BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析
    <4的時候,那麼
    BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析
    ,此時如果
    BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析
    距離過大,會導緻連乘項過多,産生梯度消失,趨近于0

是以當輸入序列過長的時候,在求取一個比較遠的時刻

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

的梯度時,需要回溯到前面的所有時刻的資訊,由于連乘項的存在,導緻前面時刻的資訊會缺失,這就是RNN中的梯度消失問題,也是所謂的long-range dependency問題(這樣劃一個約等号會不會太草率?);

梯度爆炸問題容易解決,例如采用clip的方式即可。但是梯度消失的問題比較難以解決,我們下面介紹LSTM為什麼能夠緩解梯度消失問題

3. LSTM BPTT推導及梯度消失分析

LSTM的公式如下所示

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析
BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

其中

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

可以看作之前RNN中的

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

,我們将

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

的計算公式展開如下所示:

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

那麼需要連乘的部分計算可得:

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

從之前的

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

變成了sigmoid函數,範圍在[0,1]之間,在實際參數更新中,可以通過控制使得其接近于1,是以多次連乘依然不會産生梯度消失,在

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

距離較大的情況下,依然能夠較好的利用

BP算法和RNN_RNN/LSTM BPTT詳細推導以及梯度消失問題分析

時刻的資訊進行梯度計算。

4. 一些思考

本文由兩個問題後續進行提升:

  1. LSTM部分的推導并不十分嚴謹,在RNN BPTT的基礎上進行了類比
  2. 梯度消失問題以及long-range dependency問題的定義需要明确,本文進行了約等于,是否準确還有待商榷