LSTM
使用一個特殊的存儲記憶單元可以改善
RNN
的梯度消失問題,是以在許多自然語言處理任務中它比
RNN
有更好的性能。
LSTM
單元的基本結構如下圖所示。
它由輸入門 i t i_{t} it,忘記門 f t f_{t} ft,輸出門 o t o_{t} ot,以及一個記憶單元 c t c_{t} ct組成。
給定一個文本序列 x = { x 1 , x 2 , … , x n } , x t x=\left\{x_{1}, x_{2}, \ldots, x_{n}\right\}, x_{t} x={x1,x2,…,xn},xt 表示目前時間步 t t t的輸入, h t − 1 h_{t-1} ht−1表示上一步
LSTM
的輸出。
LSTM
通過門更新記憶單元狀态,添加或删除資訊以保留與任務相關的内容。 i t i_{t} it乘以候選值 u t u_{t} ut決定了添加到記憶單元的新的輸入資訊。 f t f_{t} ft 乘以 h t − 1 h_{t-1} ht−1 決定從記憶單元狀态中移除的已有資訊。輸出門 o t o_{t} ot決定從記憶單元狀态最終的輸出資訊。
輸入門:
i t = σ ( W ( i ) x t + U ( i ) h t − 1 + b ( i ) ) i_{t}=\sigma\left(W^{(i)} x_{t}+U^{(i)} h_{t-1}+b^{(i)}\right) it=σ(W(i)xt+U(i)ht−1+b(i))
忘記門:
f t = σ ( W ( f ) x t + U ( f ) h t − 1 + b ( f ) ) f_{t}=\sigma\left(W^{(f)} x_{t}+U^{(f)} h_{t-1}+b^{(f)}\right) ft=σ(W(f)xt+U(f)ht−1+b(f))
輸出門:
o t = σ ( W ( o ) x t + U ( o ) h t − 1 + b ( o ) ) o_{t}=\sigma\left(W^{(o)} x_{t}+U^{(o)} h_{t-1}+b^{(o)}\right) ot=σ(W(o)xt+U(o)ht−1+b(o))
記憶單元候選值:
u t = tanh ( W ( u ) x t + U ( u ) h t − 1 + b ( u ) ) u_{t}=\tanh \left(W^{(u)} x_{t}+U^{(u)} h_{t-1}+b^{(u)}\right) ut=tanh(W(u)xt+U(u)ht−1+b(u))
記憶單元狀态更新:
c t = i t ⊙ u t + f t ⊙ c t − 1 c_{t}=i_{t} \odot u_{t}+f_{t} \odot c_{t-1} ct=it⊙ut+ft⊙ct−1
輸出:
h t = o t ⊙ tanh ( c t ) h_{t}=o_{t} \odot \tanh \left(c_{t}\right) ht=ot⊙tanh(ct)