由于想多分類中使用Diceloss,是以需要将[0,1,2,..N]類型的标簽轉化為onehot類型。
1、在cpu上處理
input資料類型: torch.LongTensor()
資料形狀:[bs, 1, *] 可為2D或3D資料
def make_one_hot(input, num_classes):
"""Convert class index tensor to one hot encoding tensor.
Args:
input: A tensor of shape [bs, 1, *]
num_classes: An int of number of class
Returns:
A tensor of shape [bs, num_classes, *]
"""
shape = np.array(input.shape)
shape[1] = num_classes
shape = tuple(shape)
result = torch.zeros(shape)
result = result.scatter_(1, input.cpu(), 1)
return result
2、在GPU上處理
input資料類型: torch.LongTensor().cuda()
資料形狀:[bs, 1, *] 可為2D或3D資料
def make_one_hot(input, num_classes):
"""Convert class index tensor to one hot encoding tensor.
Args:
input: A tensor of shape [bs, 1, *]
num_classes: An int of number of class
Returns:
A tensor of shape [bs, num_classes, *]
"""
shape = np.array(input.shape)
shape[1] = num_classes
shape = tuple(shape)
result = torch.zeros(shape).cuda()
result = result.scatter_(1, input, 1)
return result
3、最近版pytorch有直接的轉化為onehot的代碼。
具體我自己torch1.7可以直接使用one_hot,不知道是從哪一版開始的
"""
輸入gt尺寸為*,得到one-hot結果尺寸為(*,num_class)
"""
import torch.nn.functional as F
gt_onthot = F.one_hot(gt, num_classes=n) # n為類别數
F.one_hot 函數詳解參看參考手冊
4、溫馨提示
1、FloatTensor轉化為LongTensor:
# 此時的輸入label為FloatTensor,可在cuda,也可是cpu
label_long = label.long()
2、 Tensor增加一個次元
label_onehot = label_onehot.unsqueeze(1) #在第一維增加一個次元
3、多分類交叉熵是不需要将标簽轉為onehot的
詳情請檢視 https://blog.csdn.net/longshaonihaoa/article/details/105253553