承接上篇SimpleRNN, PyTorch中對于LSTM也有兩個方法,即nn.LSTM和nn.LSTMCell。同樣地,我們用兩種方法來做一個簡單例子的前饋。
先來看LSTMCell,執行個體化用到的參數如下:
from torch import nn
torch.nn.LSTMCell(input_size: int, hidden_size: int, bias: bool = True)
下面是官方文檔中對于公式的說明以及參數的說明。
請注意:執行個體化後的LSTM(或LSTMCell)對象,其權重是 i , f , g , o i,f,g,o i,f,g,o這四個矩陣的拼接,且其拼接順序也是 i → f → g → o i\rightarrow f\rightarrow g\rightarrow o i→f→g→o
這次我用的是台大李宏毅老師2020機器學習深度學習課程的一個例子,并且我人為做了一些改動。
規則是這樣的: x 2 = 1 x_2=1 x2=1,則更新記憶; x 2 = − 1 x_2=-1 x2=−1,則重置記憶; x 3 = 1 x_3=1 x3=1,則輸出記憶。對于激活函數,老師在三個門控用的是sigmoid,把輸入和輸出的tanh換成了線性激活(也就是原樣輸出)。
在PyTorch裡面似乎是不能人為指定非線性激活函數的,是以隻能用tanh函數作為輸入和輸出時的激活。觀察 tanh \tanh tanh函數的圖像,我們可以發現,在 [ − 0.25 , 0.25 ] [-0.25,0.25] [−0.25,0.25]這個區間裡 tanh \tanh tanh函數近似可以視作線性函數并且滿足 y = x y=x y=x。
是以我們把老師PPT上的輸入序列稍微變一下,我們要保證 x 1 x_1 x1在 [ − 0.25 , 0.25 ] [-0.25,0.25] [−0.25,0.25]這個區間裡。
x 1 x_1 x1 | 0.2 | 0.1 | -0.1 | -0.2 | 0.25 |
---|---|---|---|---|---|
x 2 x_2 x2 | 1 | 1 | -1 | ||
x 3 x_3 x3 | 1 |
現在就讓我們手動運作一下LSTMCell吧!
t t t | 1 | 2 | 3 | 4 | 5 |
---|---|---|---|---|---|
i i i | σ ( 90 ) ≈ 1 \sigma(90)\approx1 σ(90)≈1 | σ ( 90 ) ≈ 1 \sigma(90)\approx1 σ(90)≈1 | σ ( − 10 ) ≈ 0 \sigma(-10)\approx0 σ(−10)≈0 | ||
f f f | σ ( 110 ) ≈ 1 \sigma(110)\approx1 σ(110)≈1 | σ ( 110 ) ≈ 1 \sigma(110)\approx1 σ(110)≈1 | σ ( 10 ) ≈ 1 \sigma(10)\approx1 σ(10)≈1 | 1 | |
g g g | tanh ( 0.2 ) ≈ 0.2 \tanh(0.2)\approx0.2 tanh(0.2)≈0.2 | tanh ( 0.1 ) ≈ 0.1 \tanh(0.1)\approx0.1 tanh(0.1)≈0.1 | tanh ( − 0.1 ) ≈ − 0.1 \tanh(-0.1)\approx-0.1 tanh(−0.1)≈−0.1 | -0.2 | 0.25 |
o o o | σ ( − 10 ) ≈ 0 \sigma(-10)\approx0 σ(−10)≈0 | σ ( − 10 ) ≈ 0 \sigma(-10)\approx0 σ(−10)≈0 | σ ( − 10 ) ≈ 0 \sigma(-10)\approx0 σ(−10)≈0 | 1 | |
c c c | 0 × f + g × i = 0.2 0\times f+g\times i=0.2 0×f+g×i=0.2 | 0.2 × f + g × i = 0.3 0.2\times f+g\times i=0.3 0.2×f+g×i=0.3 | 0.3 × f + g × i = 0.3 0.3\times f+g\times i=0.3 0.3×f+g×i=0.3 | 0.3 | |
h h h | tanh ( c × o ) = 0 \tanh(c\times o)=0 tanh(c×o)=0 | tanh ( c × o ) = 0 \tanh(c\times o)=0 tanh(c×o)=0 | tanh ( c × o ) = 0 \tanh(c\times o)=0 tanh(c×o)=0 | 0.3 |
記住這裡的 c ( i ) , h ( i ) c^{(i)},h^{(i)} c(i),h(i),馬上我們将拿他們與PyTorch運算結果對比。
import torch
from torch import nn
from torch.autograd import Variable
batch_size = 1
seq = 5
input_size, hidden = 3, 1
lstm_cell = nn.LSTMCell(input_size=input_size, hidden_size=hidden, bias=True)
lstm_cell.weight_ih.data = torch.Tensor([[0, 100, 0], [0, 100, 0],
[1, 0, 0], [0, 0, 100]]) # 1
lstm_cell.weight_hh.data = torch.zeros(4, 1) # 2
lstm_cell.bias_ih.data = torch.Tensor([-10, 10, 0, -10]) # 3
lstm_cell.bias_hh.data = torch.zeros(4) # 4
x = Variable(torch.Tensor([[[0.2, 1, 0]],
[[0.1, 1, 0]],
[[-0.1, 0, 0]],
[[-0.2, 0, 1]],
[[0.25, -1, 0]]]))
h_n = Variable(torch.zeros(1, 1))
c_n = h_n.clone()
for step in range(seq):
h_n, c_n = lstm_cell(x[step], (h_n, c_n))
print('t=%d' % step)
print('c=%.1f' % c_n.data)
print('h=%.1f' % h_n.data)
print('-' * 40)
nn.LSTM()
再次強調一下,執行個體化後的LSTM(或LSTMCell)對象,其權重是 i , f , g , o i,f,g,o i,f,g,o這四個矩陣的拼接,且其拼接順序也是 i → f → g → o i\rightarrow f\rightarrow g\rightarrow o i→f→g→o,在#1處與#3處我是嚴格按照這個順序指派的。
在#2處和#4處,由于我們規則裡目前時刻的行動(更新/重置/輸出)隻取決于目前時刻輸入而與曆史輸入無關,是以理應給 W h i , W h f , W h g , W h o W_{hi},W_{hf},W_{hg},W_{ho} Whi,Whf,Whg,Who這些權重以及 b h i , b h f , b h g , b h o b_{hi},b_{hf},b_{hg},b_{ho} bhi,bhf,bhg,bho這些偏置置零。
看一下運作結果:
如果我們用更簡單的LSTM而不是LSTMCell:
batch_size = 1
seq = 5
input_size, hidden = 3, 1
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden, bias=True, num_layers=1)
lstm.weight_ih_l0.data = torch.Tensor([[0, 100, 0], [0, 100, 0],
[1, 0, 0], [0, 0, 100]])
lstm.weight_hh_l0.data = torch.zeros(4, 1)
lstm.bias_ih_l0.data = torch.Tensor([-10, 10, 0, -10])
lstm.bias_hh_l0.data = torch.zeros(4)
x = Variable(torch.Tensor([[[0.2, 1, 0]],
[[0.1, 1, 0]],
[[-0.1, 0, 0]],
[[-0.2, 0, 1]],
[[0.25, -1, 0]]]))
h_0 = Variable(torch.zeros(1, 1, 1)) # 1
c_0 = h_0.clone()
output, (h_n, c_n) = lstm(x, (h_0, c_0))
torch.set_printoptions(precision=1, sci_mode=False)
print('this is output:\n', output.data)
print('this is c_n:\n', c_n.data)
print('this is h_n:\n', h_n.data)
#1處的0初始化有沒有都可以,如果不對 h ( 0 ) h^{(0)} h(0)初始化的話,預設值也是零向量或零矩陣。
運作結果:
c_n和h_n用于下一時間步輸入,雖然這裡已經結束了。
我們可以看到,這三次結果是相吻合的✌