天天看點

torch.nn gather用法

官方解釋

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

Gathers values along an axis specified by dim.

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0

out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1

out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

Parameters

input (Tensor) – the source tensor

dim (int) – the axis along which to index

index (LongTensor) – the indices of elements to gather

sparse_grad (bool,optional) – If True, gradient w.r.t. input will be a sparse tensor.

out (Tensor, optional) – the destination tensor

Example:

t = torch.tensor([[1,2],[3,4]])

torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))

output:

tensor([[ 1, 1],

[ 4, 3]])

了解為在dim的方向,取index指定的input對應的值出來,

如example中,dim=1,是以沿列的方向,分别取t第一行的[0, 0]位置,即t對應的第一行[1,1],

取第二行的[1,0]位置,對應t的[4,3]

可用在分類問題中,計算每個類别softmax機率後,直接用groudtruth類别取出對應類别計算的機率。

繼續閱讀