天天看點

Seq2Seq的一些概念

Recurrent Neural Network

RNN 又叫做遞歸神經網絡或者循環神經網絡,它擅長對序列資料進行模組化處理,如時間序列資料,是指在不同時間點上收集的資料,這類資料反映了某一事物、現象在随時間的變化狀态或程度,當然這是時間,也可以是文本或圖像序列,總的來說,序列資料存在着一個特點——後面的資料跟前面的資料有關系

為什麼需要 RNN ?

神經網絡結構隻能單獨的處理一個個的輸入,前一個輸入與後一個輸入是完成沒有關系的,但是某些任務需要更好的處理序列資訊,即前面一個輸入和後面一個輸入是要有關系的,通俗點來說就是後一個輸入需要記憶前面一個輸入的資訊

比如,當我們在了解一句話時,孤立的了解單個詞是沒有意義的,隻有将上下詞聯系起來的整個序列才具有意義;當我們處理視訊時,也不能單獨分析每一帖,需要分析這些幀連接配接起來的整個序列

為了解決這一類問題,能夠更好的處理序列資訊,RNN 模型就應運而生,那麼 RNN 又是怎樣實作這樣的功能呢?

RNN 的結構

RNN 主要對序列資料進行序列處理,其基本結構如下圖所示:

Seq2Seq的一些概念

上圖是 RNN 的結構示意圖,每個箭頭表示着一次變換,也就是說箭頭帶有權值,左側是折疊起來的樣子,右側是展開的樣子,左側 A 旁邊的箭頭展現着結構中的 “循環” 概念。

在右側展開結構中我們可以看到,在 x 0 x_0 x0​ 作為輸入時,該單元的輸出分為二個方向,向上的 h o h_o ho​ 表示的是其作為一個輸出,向右箭頭表示的是其另一個輸出作為下一個單元的輸入,以此達到與下一個單元之間保持着某個聯系,即記憶功能

為了更好的了解,我們看下圖:

Seq2Seq的一些概念

簡單點來說就是:當在 x t x_t xt​ 時刻時,該單元的輸入就分為二個: S t − 1 S_{t-1} St−1​ 、 x t x_t xt​, 輸入也分為二個: S t S_{t} St​ 、 O t O_t Ot​

  • S t − 1 S_{t-1} St−1​ 表示的是 x t − 1 x_{t-1} xt−1​ 時刻的一個輸出
  • x t x_t xt​ 表示本時刻的一個輸入
  • S t S_{t} St​ 表示 x t x_t xt​ 時刻的一個輸出,将作為下一時刻的一個輸入
  • O t O_t Ot​ 表示 x t x_t xt​ 時刻的輸出

我們可以用下面的公式來表示 RNN 的計算方式:

Seq2Seq的一些概念
  • 上圖同樣與展現出了 RNN 的另一個特點:權值共享,其中 U 是完全相同的, W、V也是一樣的

那麼我們再來看看 隐藏層 S 中究竟發生了怎樣的變化

Seq2Seq的一些概念

我們可以看到 h t − 1 h_{t-1} ht−1​ 和 x t x_t xt​ 之間實際上是做了一個 ocncatenate 操作,然後再經過激活函數最終形成了一個輸出,值得注意的是它的一個次元變化

Bidirectional RNNs 雙向循環神經網絡

Seq2Seq的一些概念

基本的 RNN 結構隻能從之前時間步驟中學習,但是有時我們卻需要從未來的時間步驟中學習表示,以便更好地了解上下文環境并消除歧義,通過接下來的列子,“He said, Teddy bears are on sale” and “He said, Teddy Roosevelt was a great President。在上面的兩句話中,當我們看到“Teddy”和前兩個詞“He said”的時候,我們有可能無法了解這個句子是指President還是Teddy bears。是以,為了解決這種歧義性,我們需要往前查找。這就是雙向RNN所能實作的。

如圖所求,雙向 RNN 有二種類型的連接配接,一種是前向的(Foward RNN),這有助于我們從之前的表示中學習, 另一種是後向的(Backward RNN),這有助于我們從之後的表示中學習

正向傳播分為二個步驟:

  1. 我們先從左向右移動,從初始時間步驟開始計算,一直持續到到達最終時間步驟為止
  2. 再從右向左移動,從最後一個時間步驟開始計算,一直持續到到達最終時間步驟為止
  • 一般來說是從前往向計算,再從後往前計算,計算過程互相獨立,互不幹擾

計算預測輸出值就變成了:

y ^ < t > = g ( W y [ a → < t > , a ← < t > ] ) \hat{y}^{<t>}= g(W_y[\overrightarrow{a}^{<t>},\overleftarrow{a}^{<t>}]) y^​<t>=g(Wy​[a

<t>,a

<t>])

a → < t > \overrightarrow{a}^{<t>} a

<t>表示 Forward RNN 的激活函數, a ← < t > \overleftarrow{a}^{<t>} a

<t> 表示 Backward RNN 的激活函數,箭頭方向表示的傳遞方向

梯度消失和梯度爆炸

誤差梯度在網絡訓練中用來得到網絡參數的方向和步幅,在正确的方向下以合适的步幅更新網絡參數。

梯度爆炸:在遞歸神經網絡中,誤差梯度會在更新中累積得到一個非常大的梯度,這樣的梯度會大幅更新網絡參數,導緻網絡的不穩定,在極端情況下,權值會變得非常的大以至于結果會溢出(NaN值、無窮或非數值),當梯度爆炸發生時,網絡層之間反複乘以大于1.0的值使得梯度值成倍增長

梯度更新:如果誤差梯度在更新中累積得到一個非常小的梯度,這也就意味着權值無法更新,最終導緻訓練失敗

利用公式分析原因

經典 RNN 的結構如下圖所求:

Seq2Seq的一些概念

關于向前傳播:

假設我們的時間序列隻有三段, S 0 S_0 S0​ 為定值,神經元沒有激活函數(便于分析)就可獲得各個時間段的狀态和輸出:

t = 1   時 刻 S 1 = U X 1 + W S 0 + b 1 O 1 = V S 1 + b 2 \begin{aligned}&t = 1 \text{ }時刻\\&S_1 = UX_1 + WS_0 + b_1\\&O_1 = VS_1 + b_2\end{aligned}\\ ​t=1 時刻S1​=UX1​+WS0​+b1​O1​=VS1​+b2​​

t = 2   時 刻 S 2 = U X 2 + W S 1 + b 1 O 2 = V S 2 + b 2 \begin{aligned}&t = 2 \text{ }時刻\\&S_2 = UX_2 + WS_1 + b_1\\&O_2 = VS_2 + b_2\end{aligned}\\ ​t=2 時刻S2​=UX2​+WS1​+b1​O2​=VS2​+b2​​

t = 3   時 刻 S 3 = U X 3 + W S 2 + b 1 O 3 = V S 3 + b 2 \begin{aligned}&t = 3 \text{ }時刻\\&S_3 = UX_3 + WS_2 + b_1\\&O_3 = VS_3 + b_2\end{aligned}\\ ​t=3 時刻S3​=UX3​+WS2​+b1​O3​=VS3​+b2​​

損失函數采用交叉熵 L t = − O t ‾ l o g O t L_t=-\overline{O_t}logO_t Lt​=−Ot​​logOt​ ( O t O_t Ot​是 t 時刻的預測輸出, O t ‾ \overline{O_t} Ot​​是 t 時刻的真實輸出),那麼對于一次訓練任務中,損失函數為:

L = ∑ i = 1 T − O t ‾ l o g O t L = \sum_{i=1}^{T}-\overline{O_t}logO_t L=i=1∑T​−Ot​​logOt​

T 是序列總長度,上述公式為每一時刻損失值的累加

關于反射傳播:

我們隻對 t3 時時刻的 U、V、W 求偏導,由鍊式法則可得:

∂ L 3 ∂ V = ∂ L 3 ∂ O 3 ∂ O 3 ∂ V ∂ L 3 ∂ W = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ W + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 2 ∂ S 2 ∂ W + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ S 1 ∂ S 1 ∂ W ∂ L 3 ∂ U = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ U + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 2 ∂ S 2 ∂ U + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ S 1 ∂ S 1 ∂ U \begin{aligned}&\frac{\partial{L_3}}{\partial{V}} = \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{V}}\\&\frac{\partial{L_3}}{\partial{W}} = \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{W}} + \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_2}}\frac{\partial{S_2}}{\partial{W}} + \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{S_2}}\frac{\partial{S_2}}{\partial{S_1}}\frac{\partial{S_1}}{\partial{W}}\\&\frac{\partial{L_3}}{\partial{U}} = \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{U}} + \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_2}}\frac{\partial{S_2}}{\partial{U}} + \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{S_2}}\frac{\partial{S_2}}{\partial{S_1}}\frac{\partial{S_1}}{\partial{U}}\end{aligned} ​∂V∂L3​​=∂O3​∂L3​​∂V∂O3​​∂W∂L3​​=∂O3​∂L3​​∂S3​∂O3​​∂W∂S3​​+∂O3​∂L3​​∂S2​∂O3​​∂W∂S2​​+∂O3​∂L3​​∂S3​∂O3​​∂S2​∂S3​​∂S1​∂S2​​∂W∂S1​​∂U∂L3​​=∂O3​∂L3​​∂S3​∂O3​​∂U∂S3​​+∂O3​∂L3​​∂S2​∂O3​​∂U∂S2​​+∂O3​∂L3​​∂S3​∂O3​​∂S2​∂S3​​∂S1​∂S2​​∂U∂S1​​​

可以簡寫成:

∂ L 3 ∂ U = ∑ k = 0 3 ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S k ∂ S k ∂ U = ∑ k = 0 3 ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ( ∏ j = k − 1 3 ∂ S j ∂ S j − 1 ) ∂ S k ∂ U 任 意 時 刻 對 參 數 W 求 偏 導 的 公 式 : ∂ L 3 ∂ W = ∑ k = 0 t ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ( ∏ j = k − 1 t ∂ S j ∂ S j − 1 ) ∂ S k ∂ w \begin{aligned}&\frac{\partial{L_3}}{\partial{U}} = \sum_{k=0}^{3}\frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{S_k}}\frac{\partial{S_k}}{\partial{U}}= \sum_{k=0}^{3}\frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\left( \prod_{j=k-1}^{3}\frac{\partial{S_j}}{\partial{S_{j-1}}} \right)\frac{\partial{S_k}}{\partial{U}}\\&任意時刻對參數 W 求偏導的公式:\\&\frac{\partial{L_3}}{\partial{W}} =\sum_{k=0}^{t}\frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\left( \prod_{j=k-1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}} \right)\frac{\partial{S_k}}{\partial{w}}\\\end{aligned} ​∂U∂L3​​=k=0∑3​∂O3​∂L3​​∂S3​∂O3​​∂Sk​∂S3​​∂U∂Sk​​=k=0∑3​∂O3​∂L3​​∂S3​∂O3​​⎝⎛​j=k−1∏3​∂Sj−1​∂Sj​​⎠⎞​∂U∂Sk​​任意時刻對參數W求偏導的公式:∂W∂L3​​=k=0∑t​∂O3​∂L3​​∂S3​∂O3​​⎝⎛​j=k−1∏t​∂Sj−1​∂Sj​​⎠⎞​∂w∂Sk​​​

由此可以看出 V 求偏導不存在依賴關系,而 W、U則随時間長度存在着長期的依賴關系,因為 S t S_t St​ 會随着時間序列向前傳播,而同時 S t S_t St​ 是 U、W 的函數

如果取其中的累乘出來,其中激活函數通常是:tanh = [0, 1] 則:

∏ j = k − 1 t ∂ S j ∂ S j − 1 = ∏ j = k − 1 t t a n h ′ W \prod_{j=k-1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}} = \prod_{j=k-1}^{t}tanh^{'}W j=k−1∏t​∂Sj−1​∂Sj​​=j=k−1∏t​tanh′W

Seq2Seq的一些概念
  • 由上圖可以看出 t a n h ′ ∈ [ 0 , 1 ] tanh^{'}\in [0, 1] tanh′∈[0,1] , 也就是說大部分都是 小于1的數在做累乘,假設 W 也是一個大于0小于1的值時,當 t 很大時, ∏ j = k − 1 t t a n h ′ W 公 式 中 的 ∏ j = k − 1 t t a n h ′ \prod_{j=k-1}^{t}tanh^{'}W 公式中的 \prod_{j=k-1}^{t}tanh^{'} ∏j=k−1t​tanh′W公式中的∏j=k−1t​tanh′ 部分會趨向于 0,這就是 RNN 中梯度消失的原因
  • 同理, ∏ j = k − 1 t t a n h ′ W 公 式 中 的 \prod_{j=k-1}^{t}tanh^{'}W 公式中的 ∏j=k−1t​tanh′W公式中的 W 參數很大時,結果就會趨于無窮,這就是産生 梯度爆炸 的原因

解決辦法

面對梯度爆炸的問題,我們可以看到梯度爆炸是因為 W 參數的值過大,而 W 值随着序列長度存在長期的依賴關系,因而我們可以設定一個上限值,一旦超過上限值,就等于我們的預設值,這樣就可以解決梯度爆炸的問題了

面對梯度消失的問題,梯度消失的原因是 ∏ j = k − 1 t ∂ S j ∂ S j − 1 \prod_{j=k-1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}} ∏j=k−1t​∂Sj−1​∂Sj​​ 求導而産生的,是以想要消除這種情況就需要在求上司的時候去掉就行了,那麼怎樣去掉呢,一般有二種方法:

  • 使 ∂ S j ∂ S j − 1 ≈ 1 \frac{\partial{S_j}}{\partial{S_{j-1}}} \approx 1 ∂Sj−1​∂Sj​​≈1,那麼怎樣達到這種目标呢?答案是換一種激活函數,我們來看一下 ReLu 作為激活函數的效果:
Seq2Seq的一些概念

​ 可以看到 ReLu 導數在定義域大于0的部分是恒等于1,這樣就可以解決梯度消失的問題了

  • 使 ∂ S j ∂ S j − 1 ≈ 0 \frac{\partial{S_j}}{\partial{S_{j-1}}} \approx 0 ∂Sj−1​∂Sj​​≈0,我們可以采用 LSTM 可以達到這樣的效果,那麼 LSTM 又是怎樣實作的呢,我們在下一篇文章中再來詳細解決

參考文獻:

[1]. https://www.jiqizhixin.com/articles/2019-01-17-7

聲明:

​ 以上内容為個人了解,若有錯誤,請各位大佬指出,以便大家多作交流!

繼續閱讀