天天看點

LSTM網絡(Long Short-Term Memory )

本文基于前兩篇 1. 多層感覺機及其BP算法(Multi-Layer Perceptron) 與 2. 遞歸神經網絡(Recurrent Neural Networks,RNN)

RNN 有一個緻命的缺陷,傳統的 MLP 也有這個缺陷,看這個缺陷之前,先祭出 RNN 的 反向傳導公式與 MLP 的反向傳導公式:

\[RNN : \ \delta_h^t = f\'(a_h^t) \left (\sum_k\delta_k^tw_{hk} + \sum_{h\'} \delta^{t+1}_{h\'}w_{hh\'}   \right )\]

\[MLP : \ \delta_h =   f\'(a_h) \sum_{h\'=1}^{h_{l+1}} w_{hh\'}\delta_{h\'}\]

注意,殘差在時間次元上反向傳遞時,每經過一個時刻,就會導緻信号的大幅度衰減,為啥呢,就是因為這個非線性激活函數 $f$ ,一般這個函數的形狀如下圖:

LSTM網絡(Long Short-Term Memory )

如上圖所示,激活函數 $f$ 在在紅線以外的斜度變化很小,是以函數 $f$ 的導數 $f\'$ 取值很小,而經過以上列出的殘差反向傳遞公式可以得出,每經過一個時刻,衰減 $f\'$ 的數量級,是以經過多個時刻會導緻時間次元上梯度呈指數級的衰減,即此刻的回報信号不能影響太遙遠的過去 。在 MLP 中,如果網絡太深,這種梯度衰減會導緻網絡的前幾層的殘差趨近于 0 ,這意味着前面的隐藏層中的神經元學習速度要慢于後面的隐藏層。無論 RNN 還是 MLP ,對參數的導數都是這種形式(RNN需要在時間次元上求和):

\[\frac{\partial O}{\partial w_{ij}} = \frac{\partial O}{\partial a_{j}} \frac{\partial a_j}{\partial w_{ij}} = \delta_jb_i\]

殘差衰減的太小導緻參數的導數太小 ,進而梯度下降法中前幾層的參數隻有微乎其微的變化,對于深層的 MLP 由于梯度衰減導緻效果不如淺層的網絡,對于 RNN 就會導緻不能處理長期依賴的問題,雖然 RNN 理論上可以處理任意長度的序列,但實習應用中,RNN 很難處理長度超過 10 的序列。這種現象叫做 gradient vanishing/exploding 。下圖形象的描繪了這種現象:

LSTM網絡(Long Short-Term Memory )

對于 $t=1$ 的輸入,随着時間的推移,對于 $t >1$ 時刻的産生的影響會越來越小,由圖中的顔色的深淺代表信号的大小。這種衰減會導緻 RNN 無法處理長期依賴,舉個例子,當有一句話“I grew up in France … I speak fluent French.”  在預測該人會将一口流利的            語時,會依賴之前他的長大的環境,而序列中兩個詞語的間隔太大,這便是所說的長期依賴問題。  

LSTM網絡(Long Short-Term Memory )

對于長期以來問題,反向傳播時,梯度也會呈指數倍數的衰減,這種衰減現象導緻 RNN 無法處理長期依賴,為了克服 RNN 的這種缺陷,學者們研究了衆多方法,其中 Long Short-Term Memory 表現最為出色。使用 LSTM 子產品後,當誤差從輸出層反向傳播回來時,可以使用子產品的記憶元記下來。是以 LSTM 可以記住比較長時間内的資訊。

初始的 LSTM (Hochreiter and Schmidhuber ,1997)網絡結構類似于 RNN ,隻是把 RNN 的隐層換成了存儲塊(memeory block),如下圖左所示, memory block 中用記憶單元 (memory cell)來儲存資訊(類似于 RNN 中的隐藏節點),,每個存儲塊包含一個或多個memory cell ,如下圖左中間的 “$\oslash$” 節點如下圖所示,藍色虛線為一條遞歸自連接配接的權值為 1 的邊,保證梯度沿時間傳播時不會損失,在時刻 $t$  的輸入如下圖的 $g^t$ 所示,除接受本時刻的輸入 $x^t$ 外,還接受上一時刻的輸出 $h^{t-1}$ ,并且經過非線性激活函數 $\sigma$ ,LSTM 并不是接納所有輸入 $g^t$ ,而是在網絡中加入兩個門,輸入門(input gate)、輸出門(output gate), 門的節點數目與 memory cell 一一對應, input gate 如下圖的 $i^t$ 所示,跟輸入層一樣,接受 $x^t$ 與 $h^{t-1}$ ,經過  $\sigma$ 後産生一個 0-1 向量(次元即為 memory cell 或者 input gate 的次元),0 代表關閉 、1 代表開啟,這樣來對輸入進行控制,下圖左中的 “$\prod$ ” 表示 input gate 的輸出  $i^t$ 與本時刻輸入 $g^t$ 的輸出逐元素相乘,即 input gate 會對輸入進行過濾 ,然後存放到 memory cell 裡,現在memory cell 裡既有上一時刻 $t-1$ 的狀态,又添加了時刻 $t$ 的狀态, 即

\[s^t = g^t \odot i^t + s^{t-1}\]

memory cell 有一個循環自連接配接的權值為 1 的邊,這樣 memory cell state 中梯度沿時間傳播時不會導緻不會 vanishing 或者 exploding ,output gate 類似于 input gate 會産生一個 0-1 向量來控制 memory cell 到輸出層的輸出。即

\[ v^t = s^t \odot o^t  \]

LSTM網絡(Long Short-Term Memory )

後來為了增強 LSTM 的處理能力, Gers et al. [2000] 引入了 forget gate, LSTM 的網絡結構變成了如上圖右所示,也就是說 forget gate 取代了之前權值為 1 的邊,經過這樣的改進,memory cell 可以遺忘之前的内容,隻需将 memory cell 中的内容與 forget gate 逐元素相乘即可, forget gate  與 input/output gate 一樣,接受  $x^t$ 與 $h^{t-1}$ 作為輸入,現在的 LSTM memory cell 的更新公式為:

\[s^t = g^t \odot i^t + f^t \odot s^{t-1}\]

Gers & Schmidhuber [2000] 在以上結構的基礎上又提出了 peephole connections ,将 $t-1$ 時刻沒有經過 output gate 處理過的 memory cell 狀态送到時刻   $t$ 作為 input gate 和 output gate 的輸入,即三個門的輸入增加了了  $s^ {t-1}$ ,現在流行的網絡結構如下圖所示:

LSTM網絡(Long Short-Term Memory )

三個門協作使得  LSTM 存儲塊可以存取長期資訊,比如說隻要輸入門保持關閉,記憶單元的資訊就不會被之後時刻的輸入所覆寫。下圖形象的描述了這個過程,在 Hidden Layer 中每個節點都是一個 memeory block ,每個 memeory block 的包含三個門,左邊為 forget gate ,下邊尾 input gate ,上邊為 output gate ,門有打開關閉兩種狀态,分别由 "$\bigcirc $" 與 "$-$" 來表示。可見對于時刻 1 的輸入,隻要之後時刻的 input gate 保持關閉,forget gate 保持打開,便可以在不影響 memory cell 的情況下随時開啟 output gate 來獲得 memory cell 的内容。對于梯度反向傳播時,同樣可以通過這種方式來保持殘差不會過度衰減。

LSTM網絡(Long Short-Term Memory )

接下來本文所涉及的将是詳細 LSTM 的 BP 過程,網絡結構采用的是 Gers & Schmidhuber [2000]所提出的 LSTM 結構,值得注意的是,這裡對 memory cell 的輸出增加了激活函數 $h$ , 之前的 $h$ 可以了解為線性的,這裡先聲明一些符号表示: $w_{ij}$ 表示 單元 $i$ 到單元 $j$ 的權值,$a_j^t$ 表示時刻 $t$ 單元  $j$ 的輸入,$b_j^t = f(a_j^t)$ 表示對單元 $j$ 的輸入做非線性映射,$\iota$  、 $\phi$  、 $\omega$ 分别代表 input gate 、forget gate、 output gate ,$C$ 用來表示 memroy cell 的數量,  $s^t_c$ 表示 memeory cell $c$ 在時刻  $t$ 的狀态, $f$ 表示門的激活函數(通常為 $sigmod$ 函數), $g$ 與 $h$ 分别表示 memory cell 輸入與輸出的激活函數,用 $I$ 表示輸入層大小, $H$ 表示隐層 memory cell 的大小(其實 $H = C$,這裡隻是為了友善表示,因為 memory cell 的輸出   $b_h^t$ 會往下個時刻傳輸,其權值可表示為 $w_{h.}$ , memrory cell 本身的權值可用  $w_ {c.}$ 來表示) , $K$ 表示輸出層的大小。 待序列為 $t = 1...T$ ,時刻 $t$ 的輸入為 $x^t$ ,注意 $b^0 = 0$ , 殘差 $\delta ^{T+1} = 0$ 。

  • forget gate : 在 LSTM 的 memory block 中,隻有上一時刻 memory cell 的輸出 $ b_h^t$ 會傳送到本單元 ,其他資料比如 memory cell state 或者 memory cell  input 等隻在單元内部可見,forget gate 是用來控制上個時刻的 memory cell state 即 $s^{t-1}$ :

\[a^t_{\phi } = \sum_iw_{i \phi } x_i^t + \sum_hw_{h \phi}b_{h}^{t-1}+ \sum_cw_{c\phi}s_c^{t-1} \]

\[b_{\phi }^t = f(a_{\phi}^t)\]

  • input gate : 這個門控制目前時刻 memory cell state 的輸入:

\[a^t_{\iota } = \sum_iw_{i \iota } x_i^t + \sum_hw_{h \iota}b_{h}^{t-1}+ \sum_cw_{c\iota}s_c^{t-1} \]

\[b_{\iota }^t = f(a_{\iota}^t)\]

  • memory cell : 對于時刻 $t-1 \rightarrow  t$ , memroy cell 的資訊是這樣變化的 ,首先對 $t-1$  時刻 memory cell 的狀态用 forget gate 進行過濾($b_{\phi}^t s_c^{t-1}$),看要遺忘或者儲存哪些資訊,然後擷取現在時刻 $t$ 的輸入資訊($g(a_c^t)$),用 input gate 進行過濾 ($b_{\iota }^tg(a_c^t)$),過濾完後相加就完成了$t-1 \rightarrow  t$ 時刻的 memory cell 狀态的轉變 :

\[a^t_c = \sum_i w_{ic} x_i^t + \sum_h w_{hc}b_{h}^{t-1} \]

\[s_c^t = b_{\phi}^t s_c^{t-1} + b_{\iota }^tg(a_c^t)\]

  • output gate : 這個門會控制 cell state 的輸出:

\[a^t_{\omega } = \sum_iw_{i \omega } x_i^t + \sum_hw_{h \omega }b_{h}^{t-1}+ \sum_cw_{c\omega }s_c^{t} \]

\[b_{\omega }^t = f(a_{\omega }^t)\]

  • memory cell output : 計算 memory cell 的輸出 ,由 output gate 控制,這個輸出也會作為下一時刻整個 memory block 的輸入(類似于 RNN 的隐層傳遞)

\[b_c^t = b_{\omega}^t h(s_c^t)\]

接下來便是殘差的反向傳導,對于輸出層,同 RNN 一般是 $softmax$ 或者 $logistic$ ,這裡首先定義:

\[\epsilon_c^t=\frac{\partial O}{\partial b_c^t}=\sum_k\frac{\partial O}{\partial a_k^t} \frac{\partial a_k^t}{\partial b_c^t}+\sum_{h}\frac{\partial O}{\partial a_h^t} \frac{\partial a_h^t}{\partial b_c^t}=\sum_{k} w_{ck}\delta_k^t+\sum_hw_{ch}\delta_h^{t+1} \] 

接下來,殘差傳導至 output gate :

\[\delta_\omega^t=\frac{\partial O}{\partial a_\omega^t}=\sum_c \frac{\partial O}{\partial b_c^t}\frac{\partial b_c^t}{\partial b_\omega^t}\frac{\partial b_\omega^t}{\partial a_\omega^t} =f\'(a_\omega^t)\sum_c \epsilon_c^t h(s_c^t) \]

現在再定義一個輔助變量:

\[\epsilon_s^t=\frac{\partial \mathcal{L}}{\partial s_c^t}

=\frac{\partial O}{\partial b_c^t} \frac{\partial b_c^t}{\partial h(s_c^t)} \frac{\partial h(s_c^t)}{\partial s_c^t}

+\frac{\partial O}{\partial s_c^{t+1}} \frac{\partial s_c^{t+1}}{\partial s_c^t}

+\frac{\partial O}{\partial a_\omega^t} \frac{\partial a_\omega^t}{\partial s_c^t}

+\frac{\partial O}{\partial a_\iota^t} \frac{\partial a_\iota^t}{\partial s_c^t}

+\frac{\partial O}{\partial a_\phi^t} \frac{\partial a_\phi^t}{\partial s_c^t} \Rightarrow\]

\[\epsilon_s^t=b_w^th\'(s_c^t)\epsilon_c^t+b_\phi^{t+1}\epsilon_s^{t+1}+w_{c\omega}\delta_\omega^t+w_{c\iota}\delta_\iota^{t+1} +w_{c\phi}\delta_\phi^{t+1}\]

這就是 bp 中最複雜的公式了,依次解釋下各項。首先,看memory block的圖,檢視該單元指向輸出單元的所有路徑,沒有一條不同的路徑就代表一項;然後運用鍊式法則展開每個路徑;就得到後向傳播中該單元的梯度$\delta$。這個輔助變量中可以看到後三項來自于cell state 對三個 gate 的監督,即 peephole ,是以若不采用 peephole 的方式就可以省略。第二項來自于下一時刻的狀态誤差,其實是 forget gate 對目前狀态的調節作用。

接下來誤差傳播到 memory cell :

\[\delta_c^t =\frac{\partial O}{\partial a_c^t}=\frac{\partial O}{\partial s_c^t}\frac{\partial s_c^t}{\partial g(a_c^t)}\frac{\partial g(a_c^t)}{\partial a_c^t}=\epsilon_c^t b_\iota^t g\'(a_c^t)\]

最後分别傳導至 forget gate $\phi$ 與 輸入門 $\iota$:

\[\delta_\phi^t =\frac{\partial O}{\partial a_\phi^t}=\sum_c\frac{\partial O}{\partial s_c^t}\frac{\partial s_c^t}{\partial b_\phi^t}\frac{\partial b_\phi^t}{\partial a_\phi^t}=f\'(a_\phi^t)\sum_c s_c^{t-1}\epsilon_s^t \]

\[\delta_\iota^t =\frac{\partial O}{\partial a_\iota^t}=\sum_c\frac{\partial O}{\partial s_c^t}\frac{\partial s_c^t}{\partial b_\iota^t}\frac{\partial b_\iota^t}{\partial a_\iota^t}=f\'(a_\iota^t)\sum_c g(a_c^{t-1})\epsilon_s^t\]

 殘差傳導完成後,直接用殘差對權重 $w_{ij}$ 進行求導即可 (這裡 $b_i^t$ 可代表輸入 $x_i^t$、$b_h^{t-1}$、$s_c^{t-1}$):

\[\frac{\partial O}{\partial w_{ij}} = \sum_t \frac{\partial O}{\partial a_j^t}\frac{\partial a_j^t}{\partial w_{ij}} = \sum_t \delta_j^tb_i^t\]

參考:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

     Supervised Sequence Labelling with Recurrent Neural Networks

     http://ethancao.cn/2015/12/07/learning-LSTM.html 

LSTM網絡(Long Short-Term Memory )