天天看點

dice損失函數

def dice_coefficient(y_true_cls, y_pred_cls,
                     training_mask):
    '''
    dice loss
    :param y_true_cls:
    :param y_pred_cls:
    :param training_mask:
    :return:
    '''
    eps = 1e-5
    intersection = torch.sum(y_true_cls * y_pred_cls * training_mask)
    union = torch.sum(y_true_cls * training_mask) + torch.sum(y_pred_cls * training_mask) + eps
    loss = 1. - (2 * intersection / union)

    return loss
           

今天使用dice損失函數,發現dice損失是負值,也就是說(2 * intersection / union)的值大于1。這絕對是錯的。後續發現在最後一個卷積中使用的是nn.ReLU(inplace=True)激活函數。relu激活函數僅對負值進行歸零操作,對正值不處理。是以會出現像素值的預測機率大于1的情況。是以換為,nn.Sigmoid()解決該問題。