天天看點

标準的LSTM網絡以及公式

LSTM

使用一個特殊的存儲記憶單元可以改善

RNN

的梯度消失問題,是以在許多自然語言處理任務中它比

RNN

有更好的性能。

LSTM

單元的基本結構如下圖所示。

标準的LSTM網絡以及公式

它由輸入門 i t i_{t} it​,忘記門 f t f_{t} ft​,輸出門 o t o_{t} ot​,以及一個記憶單元 c t c_{t} ct​組成。

給定一個文本序列 x = { x 1 , x 2 , … , x n } , x t x=\left\{x_{1}, x_{2}, \ldots, x_{n}\right\}, x_{t} x={x1​,x2​,…,xn​},xt​ 表示目前時間步 t t t的輸入, h t − 1 h_{t-1} ht−1​表示上一步

LSTM

的輸出。

LSTM

通過門更新記憶單元狀态,添加或删除資訊以保留與任務相關的内容。 i t i_{t} it​乘以候選值 u t u_{t} ut​決定了添加到記憶單元的新的輸入資訊。 f t f_{t} ft​ 乘以 h t − 1 h_{t-1} ht−1​ 決定從記憶單元狀态中移除的已有資訊。輸出門 o t o_{t} ot​決定從記憶單元狀态最終的輸出資訊。

輸入門:

i t = σ ( W ( i ) x t + U ( i ) h t − 1 + b ( i ) ) i_{t}=\sigma\left(W^{(i)} x_{t}+U^{(i)} h_{t-1}+b^{(i)}\right) it​=σ(W(i)xt​+U(i)ht−1​+b(i))

忘記門:

f t = σ ( W ( f ) x t + U ( f ) h t − 1 + b ( f ) ) f_{t}=\sigma\left(W^{(f)} x_{t}+U^{(f)} h_{t-1}+b^{(f)}\right) ft​=σ(W(f)xt​+U(f)ht−1​+b(f))

輸出門:

o t = σ ( W ( o ) x t + U ( o ) h t − 1 + b ( o ) ) o_{t}=\sigma\left(W^{(o)} x_{t}+U^{(o)} h_{t-1}+b^{(o)}\right) ot​=σ(W(o)xt​+U(o)ht−1​+b(o))

記憶單元候選值:

u t = tanh ⁡ ( W ( u ) x t + U ( u ) h t − 1 + b ( u ) ) u_{t}=\tanh \left(W^{(u)} x_{t}+U^{(u)} h_{t-1}+b^{(u)}\right) ut​=tanh(W(u)xt​+U(u)ht−1​+b(u))

記憶單元狀态更新:

c t = i t ⊙ u t + f t ⊙ c t − 1 c_{t}=i_{t} \odot u_{t}+f_{t} \odot c_{t-1} ct​=it​⊙ut​+ft​⊙ct−1​

輸出:

h t = o t ⊙ tanh ⁡ ( c t ) h_{t}=o_{t} \odot \tanh \left(c_{t}\right) ht​=ot​⊙tanh(ct​)

繼續閱讀