gather函數
今天在用PyTorch複現softmax的時候,參考的書籍為《Dive into DL Pytorch》。在書裡面,關于gather函數原文是這樣叙述的:
上一節中,我們介紹了softmax回歸使用的交叉熵損失函數。為了得到标簽的預測機率,我們可以使用gather函數。在下面的例子中,變量y_hat是2個樣本在3個類别的預測機率,變量y是這2個樣本的标簽類别。通過使用gather函數,我們得到了2個樣本的标簽的預測機率。在代碼中,标簽類别的離散值是從0開始逐一遞增的。
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))
輸出為:
tensor([[0.1000],
[0.5000]])
看到這裡的時候,不太能了解兩點:
- torch.LongTensor是什麼樣的資料格式;
- gather函數到底在中間有什麼樣的作用。
1. torch.LongTensor
是以torch.LongTensor訓示的資料類型為張量,但裡面的元素為Long類型
2. gather函數的作用
很顯然在這段代碼裡面:gather函數的第一個參數’1‘,指定的是dim,即次元,也就是對哪個次元進行操作,此處為對the first dim(或者說第1軸)進行操作。第二個參數y.view(-1, 1), 其展開應該為:
torch.LongTensor([[0],
[2]])
這裡 y.view(-1, 1) 裡面的元素應該是作為index也即索引的意思。是以這句代碼的了解就是對 y_hat 在第一軸上根據y.view(-1, 1)提供的索引進行取值,最後得到的輸出就是:
tensor([[0.1000],
[0.5000]])