天天看點

深入了解PyTorch中的gather函數

gather函數

今天在用PyTorch複現softmax的時候,參考的書籍為《Dive into DL Pytorch》。在書裡面,關于gather函數原文是這樣叙述的:

上一節中,我們介紹了softmax回歸使用的交叉熵損失函數。為了得到标簽的預測機率,我們可以使用gather函數。在下面的例子中,變量y_hat是2個樣本在3個類别的預測機率,變量y是這2個樣本的标簽類别。通過使用gather函數,我們得到了2個樣本的标簽的預測機率。在代碼中,标簽類别的離散值是從0開始逐一遞增的。
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))
           

輸出為:

tensor([[0.1000],
        [0.5000]])
           

看到這裡的時候,不太能了解兩點:

  1. torch.LongTensor是什麼樣的資料格式;
  2. gather函數到底在中間有什麼樣的作用。

1. torch.LongTensor

是以torch.LongTensor訓示的資料類型為張量,但裡面的元素為Long類型

2. gather函數的作用

很顯然在這段代碼裡面:gather函數的第一個參數’1‘,指定的是dim,即次元,也就是對哪個次元進行操作,此處為對the first dim(或者說第1軸)進行操作。第二個參數y.view(-1, 1), 其展開應該為:

torch.LongTensor([[0],
 				 [2]])
           

這裡 y.view(-1, 1) 裡面的元素應該是作為index也即索引的意思。是以這句代碼的了解就是對 y_hat 在第一軸上根據y.view(-1, 1)提供的索引進行取值,最後得到的輸出就是:

tensor([[0.1000],
        [0.5000]])
           

繼續閱讀