官方解釋
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
Parameters
input (Tensor) – the source tensor
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to gather
sparse_grad (bool,optional) – If True, gradient w.r.t. input will be a sparse tensor.
out (Tensor, optional) – the destination tensor
Example:
t = torch.tensor([[1,2],[3,4]])
torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
output:
tensor([[ 1, 1],
[ 4, 3]])
了解為在dim的方向,取index指定的input對應的值出來,
如example中,dim=1,是以沿列的方向,分别取t第一行的[0, 0]位置,即t對應的第一行[1,1],
取第二行的[1,0]位置,對應t的[4,3]
可用在分類問題中,計算每個類别softmax機率後,直接用groudtruth類别取出對應類别計算的機率。