天天看點

Pytorch 類别标簽轉換one-hot編碼

這裡用到了Pytorch的scatter_函數:

scatter_(dim, index, src) → Tensor

Writes all values from the tensor

src

into

self

at the indices specified in the

index

tensor. For each value in

src

, its output index is specified by its index in

src

for

dimension != dim

and by the corresponding value in

index

for

dimension = dim

.

For a 3-D tensor,

self

is updated as:

self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0

self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1

self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2

注意要保證self和index次元一緻

對于分類問題,标簽可以是類别索引值也可以是one-hot表示。以10類别分類為例,lable=[3] 和label=[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]是一緻的.

>>>class_num = 10
>>>batch_size = 4
>>>label = torch.LongTensor(batch_size, 1).random_() % class_num
 3
 0
 0
 8

>>>one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)
    0     0     0     1     0     0     0     0     0     0
    1     0     0     0     0     0     0     0     0     0
    1     0     0     0     0     0     0     0     0     0
    0     0     0     0     0     0     0     0     1     0
           

自己做圖像分割時,把圖像标簽轉成one-hot編碼形式,原圖像(1,512,512),生成one-hot(class_nums, 512, 512):

gt_onehot = torch.zeros((class_nums, gt.shape[1], gt.shape[2]))
gt_onehot.scatter_(0, gt.long(), 1)
           

參考:

1.https://pytorch.org/docs/stable/tensors.html?highlight=scatter_#torch.Tensor.scatter_

2.PyTorch——Tensor_把索引标簽轉換成one-hot标簽表示

繼續閱讀