這裡用到了Pytorch的scatter_函數:
scatter_(dim, index, src) → Tensor
Writes all values from the tensorinto
src
at the indices specified in the
self
tensor. For each value in
index
, its output index is specified by its index in
src
for
src
and by the corresponding value in
dimension != dim
for
index
dimension = dim
.
For a 3-D tensor,
is updated as:
self
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
注意要保證self和index次元一緻
對于分類問題,标簽可以是類别索引值也可以是one-hot表示。以10類别分類為例,lable=[3] 和label=[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]是一緻的.
>>>class_num = 10
>>>batch_size = 4
>>>label = torch.LongTensor(batch_size, 1).random_() % class_num
3
0
0
8
>>>one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)
0 0 0 1 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 1 0
自己做圖像分割時,把圖像标簽轉成one-hot編碼形式,原圖像(1,512,512),生成one-hot(class_nums, 512, 512):
gt_onehot = torch.zeros((class_nums, gt.shape[1], gt.shape[2]))
gt_onehot.scatter_(0, gt.long(), 1)
參考:
1.https://pytorch.org/docs/stable/tensors.html?highlight=scatter_#torch.Tensor.scatter_
2.PyTorch——Tensor_把索引标簽轉換成one-hot标簽表示