天天看点

【Pytorch】scatter函数详解

在pytorch中,scatter是一个非常实用的映射函数,其将一个源张量(source)中的值按照指定的轴方向(dim)和对应的位置关系(index)逐个填充到目标张量(target)中,其函数写法为:

其中各变量及参数的说明如下:

  • target

    :即目标张量,将在该张量上进行映射
  • src

    :即源张量,将把该张量上的元素逐个映射到目标张量上
  • dim

    :指定轴方向,定义了填充方式。对于二维张量,

    dim=0

    表示逐列进行行填充,而

    dim=1

    表示逐列进行行填充
  • index

    : 按照轴方向,在

    target

    张量中需要填充的位置

为了保证scatter填充的有效性,需要注意:

(1)

target

张量在

dim

方向上的长度不小于

source

张量,且在其它轴方向的长度与

source

张量一般相同。这里的一般是指:scatter操作本身有broadcast机制。

(2)

index

张量的shape一般与

source

,从而定义了每个

source

元素的填充位置。这里的一般是指broadcast机制下的例外情况。

下面以一个实际的案例来观察scatter函数:

import torch
a = torch.arange(10).reshape(2,5).float()
b = torch.zeros(3, 5))
b_= b.scatter(dim=0, index=torch.LongTensor([[1, 2, 1, 1, 2], [2, 0, 2, 1, 0]]),src=a)
print(b_)

# tensor([[0, 6, 0, 0, 9],
#        [0, 0, 2, 8, 0],
#        [5, 1, 7, 0, 4]])
           

整个函数的操作过程见下面的示意图。因为设定了

dim=0

,所以会逐列将

source

中的元素按照

index

中的位置信息,放入

target

张量中。

【Pytorch】scatter函数详解

scatter函数的一个典型应用就是在分类问题中,将目标标签转换为one-hot编码形式,如:

labels = torch.LongTensor([1,3])
targets = torch.zeros(2, 5)
targets.scatter(dim=1, index=labels.unsqueeze(-1), src=torch.tensor(1))
# 注意dim=1,即逐样本的进行列填充
# 返回值为 tensor([[0, 1, 0, 0, 0],
#        [0, 0, 0, 1, 0]])
           

继续阅读