在pytorch中,scatter是一个非常实用的映射函数,其将一个源张量(source)中的值按照指定的轴方向(dim)和对应的位置关系(index)逐个填充到目标张量(target)中,其函数写法为:
其中各变量及参数的说明如下:
-
:即目标张量,将在该张量上进行映射target
-
:即源张量,将把该张量上的元素逐个映射到目标张量上src
-
:指定轴方向,定义了填充方式。对于二维张量,dim
表示逐列进行行填充,而dim=0
表示逐列进行行填充dim=1
-
: 按照轴方向,在index
张量中需要填充的位置target
为了保证scatter填充的有效性,需要注意:
(1)
target
张量在
dim
方向上的长度不小于
source
张量,且在其它轴方向的长度与
source
张量一般相同。这里的一般是指:scatter操作本身有broadcast机制。
(2)
index
张量的shape一般与
source
,从而定义了每个
source
元素的填充位置。这里的一般是指broadcast机制下的例外情况。
下面以一个实际的案例来观察scatter函数:
import torch
a = torch.arange(10).reshape(2,5).float()
b = torch.zeros(3, 5))
b_= b.scatter(dim=0, index=torch.LongTensor([[1, 2, 1, 1, 2], [2, 0, 2, 1, 0]]),src=a)
print(b_)
# tensor([[0, 6, 0, 0, 9],
# [0, 0, 2, 8, 0],
# [5, 1, 7, 0, 4]])
整个函数的操作过程见下面的示意图。因为设定了
dim=0
,所以会逐列将
source
中的元素按照
index
中的位置信息,放入
target
张量中。
scatter函数的一个典型应用就是在分类问题中,将目标标签转换为one-hot编码形式,如:
labels = torch.LongTensor([1,3])
targets = torch.zeros(2, 5)
targets.scatter(dim=1, index=labels.unsqueeze(-1), src=torch.tensor(1))
# 注意dim=1,即逐样本的进行列填充
# 返回值为 tensor([[0, 1, 0, 0, 0],
# [0, 0, 0, 1, 0]])