天天看點

pytorch—torch.tensor.scatter操作解析

了解scatter操作:

tensor_A.scatter_(dim, index, tensor_B)

: tensor_B的每個元素,都按照 index 被scatter(可以了解為填充)到目标tensor_A中。

(1) index和源tensor_B次元一緻;

(2) tensor_A一般是全零的張量,其某些特定位置的值由 tensor_B 中的值填充。

(3) 注意如何根據index選取tensor_B中的值:

對于2-D tensor:

if dim=0, tensor_A[index[i][j]][j] = tensor_B[i][j];

if dim=1, tensor_A[i][index[i][j]] = tensor_B[i][j];
           

對于3-D tensor:

if dim = 0,tensor_A[index[i][j][k]][j][k] = tensor_B[i][j][k]
if dim = 1,tensor_A[i][index[i][j][k]][k] = tensor_B[i][j][k]
if dim = 2,tensor_A[i][j][index[i][j][k]] = tensor_B[i][j][k]
           

舉例:

pytorch—torch.tensor.scatter操作解析
pytorch—torch.tensor.scatter操作解析

如果對您有幫助,麻煩點贊關注,這真的對我很重要!!!如果需要互關,請評論或者私信!

pytorch—torch.tensor.scatter操作解析