Binary Cross-Entropy Loss Function

Plain cross-entropy loss

For binary labels, the plain cross-entropy reduces to the binary cross-entropy (BCE), which PyTorch implements as BCELoss on predicted probabilities.

BCEWithLogitsLoss

BCEWithLogitsLoss is simply BCELoss with a sigmoid composed on the outside: it takes raw logits, applies the sigmoid internally, and then computes the binary cross-entropy. Combining both operations in one class is also more numerically stable than a separate sigmoid followed by BCELoss, because it can use the log-sum-exp trick.
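This composition can be checked directly. The sketch below compares BCEWithLogitsLoss on raw logits against an explicit sigmoid followed by BCELoss (the names `logits` and `labels` are illustrative, not from the original example):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(10, 64)                    # raw model outputs, no sigmoid applied
labels = torch.randint(0, 2, (10, 64)).float()  # binary targets

fused = torch.nn.BCEWithLogitsLoss()(logits, labels)
manual = torch.nn.BCELoss()(torch.sigmoid(logits), labels)

print(torch.allclose(fused, manual, atol=1e-5))  # True
```

For moderate logits the two agree to floating-point precision; the fused version is the numerically safer choice when logits are extreme.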

Binary cross-entropy loss example:
import torch
target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
output = torch.full([10, 64], 1.5)  # a prediction (raw logits)
pos_weight = torch.ones([64])  # all positive-class weights equal to 1
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
print(criterion(output, target))  # tensor(0.2014) = -log(sigmoid(1.5))

Verification

Computing the same loss by hand confirms the result. Note that output and target are torch tensors, so they are converted to NumPy arrays first:

import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

out = output.numpy()
tgt = target.numpy()
# element-wise binary cross-entropy, computed by hand
loss = -tgt * np.log(sigmoid(out)) - (1 - tgt) * np.log(1 - sigmoid(out))
print(np.mean(loss))  # ≈ 0.2014, matches BCEWithLogitsLoss
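The pos_weight argument in the example above scales the loss contribution of positive targets, which is useful for imbalanced labels. A minimal sketch (the weight 3.0 is an arbitrary illustration, not from the original example):

```python
import torch

target = torch.ones([10, 64])   # every label is positive
output = torch.full([10, 64], 1.5)

plain = torch.nn.BCEWithLogitsLoss()(output, target)
weighted = torch.nn.BCEWithLogitsLoss(pos_weight=torch.full([64], 3.0))(output, target)

# with all-positive targets, the positive term is the whole loss,
# so the weighted loss is 3x the plain loss
print(weighted / plain)  # ratio ≈ 3
```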
