天天看點

[PyTorch] rnn,lstm,gru中輸入輸出次元

本文中的RNN泛指LSTM,GRU等等

CNN中和RNN中batchSize的預設位置是不同的。

**CNN中**:batchsize的位置是position 0.
**RNN中**:batchsize的位置是position 1.
           

在RNN中輸入資料格式:

對于最簡單的RNN,我們可以使用兩種方式來調用,torch.nn.RNNCell(),它隻接受序列中的單步輸入,必須顯式的傳入隐藏狀态。torch.nn.RNN()可以接受一個序列的輸入,預設會傳入一個全0的隐藏狀态,也可以自己申明隐藏狀态傳入。

輸入大小是**三維**tensor[seq_len,batch_size,input_dim]

input_dim是輸入的次元,比如是128
batch_size是一次往RNN輸入句子的數目,比如是5。
seq_len是一個句子的最大長度,比如15
**是以千萬注意,RNN輸入的是序列,一次把批次的所有句子都輸入了,得到的ouptut和hidden都是這個批次的所有的輸出和隐藏狀态,次元也是三維。**
**可以了解為現在一共有batch_size個獨立的RNN元件,RNN的輸入次元是input_dim,總共輸入seq_len個時間步,則每個時間步輸入到這個整個RNN子產品的次元是[batch_size,input_dim]
           
# 構造RNN網絡,x的次元5,隐層的次元10,網絡的層數2
rnn_seq = nn.RNN(5, 10,2)  
# 構造一個輸入序列,句長為 6,batch 是 3, 每個單詞使用長度是 5的向量表示
x = torch.randn(6, 3, 5)
#out,ht = rnn_seq(x,h0) 
out,ht = rnn_seq(x) #h0可以指定或者不指定

           

問題1:這裡out、ht的size是多少呢?

回答:out:6 * 3 * 10, ht: 2 * 3 * 10,out的輸出次元[seq_len,batch_size,output_dim],ht的次元[num_layers * num_directions, batch, hidden_size],如果是單向單層的RNN那麼一個句子隻有一個hidden。

問題2:out[-1]和ht[-1]是否相等?

回答:相等,隐藏單元就是輸出的最後一個單元,可以想象,每個的輸出其實就是那個時間步的隐藏單元

LSTM的輸出多了一個memory單元

# 輸入次元 50,隐層100維,兩層
lstm_seq = nn.LSTM(50, 100, num_layers=2)
# 輸入序列seq= 10,batch =3,輸入次元=50
lstm_input = torch.randn(10, 3, 50)
out, (h, c) = lstm_seq(lstm_input) # 使用預設的全 0 隐藏狀态

           

問題1:out和(h,c)的size各是多少?

回答:out:(10 * 3 * 100),(h,c):都是(2 * 3 * 100)

問題2:out[-1,:,:]和h[-1,:,:]相等嗎?

回答: 相等

GRU比較像傳統的RNN

gru_seq = nn.GRU(10, 20,2) # x_dim,h_dim,layer_num
gru_input = torch.randn(3, 32, 10) # seq,batch,x_dim
out, h = gru_seq(gru_input)