天天看点

pytorch将标签转为onehot

由于想多分类中使用Diceloss,所以需要将[0,1,2,..N]类型的标签转化为onehot类型。

1、在cpu上处理

input数据类型: torch.LongTensor()

        数据形状:[bs, 1, *]         可为2D或3D数据      

def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.
    Args:
         input: A tensor of shape [bs, 1, *]
         num_classes: An int of number of class
    Returns:
        A tensor of shape [bs, num_classes, *]
    """
    shape = np.array(input.shape)
    shape[1] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape)
    result = result.scatter_(1, input.cpu(), 1)

    return result
           

2、在GPU上处理

input数据类型: torch.LongTensor().cuda()

        数据形状:[bs, 1, *]         可为2D或3D数据 

def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.
    Args:
         input: A tensor of shape [bs, 1, *]
         num_classes: An int of number of class
    Returns:
        A tensor of shape [bs, num_classes, *]
    """
    shape = np.array(input.shape)
    shape[1] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape).cuda()
    result = result.scatter_(1, input, 1)

    return result
           

3、最近版pytorch有直接的转化为onehot的代码。

具体我自己torch1.7可以直接使用one_hot,不知道是从哪一版开始的

"""
输入gt尺寸为*,得到one-hot结果尺寸为(*,num_class)
"""
import torch.nn.functional as F
gt_onthot = F.one_hot(gt, num_classes=n)   # n为类别数

           

F.one_hot 函数详解参看参考手册

4、温馨提示

1、FloatTensor转化为LongTensor:

# 此时的输入label为FloatTensor,可在cuda,也可是cpu
label_long = label.long()
           

2、 Tensor增加一个维度

label_onehot = label_onehot.unsqueeze(1)   #在第一维增加一个维度
           

3、多分类交叉熵是不需要将标签转为onehot的

详情请查看  https://blog.csdn.net/longshaonihaoa/article/details/105253553