
Binary Cross-Entropy Loss Functions

Ordinary cross-entropy loss

The ordinary cross-entropy loss is used for multi-class classification, where each sample belongs to exactly one of several mutually exclusive classes.
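A minimal sketch of the ordinary (multi-class) cross-entropy loss in PyTorch; the tensor shapes and values below are illustrative assumptions, not from the original post:

```python
import torch

# Hypothetical data: batch of 3 samples, 5 mutually exclusive classes
logits = torch.tensor([[2.0, 0.1, -1.0, 0.3, 0.0],
                       [0.2, 1.5,  0.0, -0.5, 0.1],
                       [0.0, 0.0,  0.0, 0.0, 3.0]])
labels = torch.tensor([0, 1, 4])  # ground-truth class indices

# CrossEntropyLoss fuses log-softmax and negative log-likelihood,
# so it takes raw logits, not probabilities
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(logits, labels)
print(loss)
```

Note that, like BCEWithLogitsLoss below, CrossEntropyLoss expects raw logits and applies the normalization internally for numerical stability.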

Binary cross-entropy loss (BCELoss)

The binary cross-entropy loss treats each output as an independent yes/no prediction. PyTorch's BCELoss expects inputs that have already been squashed into probabilities, e.g. by a sigmoid.
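A small sketch of BCELoss, assuming the model outputs have already been passed through a sigmoid (the values here are made up for illustration):

```python
import torch

logits = torch.tensor([1.5, -0.3, 0.8])
probs = torch.sigmoid(logits)           # BCELoss expects probabilities in (0, 1)
target = torch.tensor([1.0, 0.0, 1.0])  # binary labels

criterion = torch.nn.BCELoss()
loss = criterion(probs, target)

# The same loss computed directly on raw logits, in one numerically
# stable step, is what BCEWithLogitsLoss provides:
loss_logits = torch.nn.BCEWithLogitsLoss()(logits, target)
```

The two results agree, which is exactly the relationship the next section describes.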

BCEWithLogitsLoss

It simply composes a sigmoid on the outside of the binary cross-entropy loss, so it accepts raw logits directly; fusing the two operations into one layer is more numerically stable than applying a sigmoid followed by BCELoss.

Example:
import torch
target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
output = torch.full([10, 64], 1.5)  # A prediction (logit)
pos_weight = torch.ones([64])  # All weights are equal to 1
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target)  # -log(sigmoid(1.5)) ≈ 0.2014

Verification

import numpy as np
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
# Recompute the loss by hand; convert the tensors to NumPy arrays
# rather than mixing NumPy ufuncs with torch tensors
t, o = target.numpy(), output.numpy()
loss = -t * np.log(sigmoid(o)) - (1 - t) * np.log(1 - sigmoid(o))
loss.mean()  # ≈ 0.2014, matching criterion(output, target) above
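Since the snippet above sets pos_weight to all ones, it has no visible effect there. As a sketch of what it actually does (values assumed for illustration): pos_weight multiplies only the positive-class term of the loss, which is useful when positive examples are rare:

```python
import torch

logit = torch.tensor([1.5])
target = torch.ones(1)  # a single positive example

base = torch.nn.BCEWithLogitsLoss()(logit, target)
weighted = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0]))(logit, target)

# For an all-positive target the loss is -pos_weight * log(sigmoid(x)),
# so with pos_weight = 3 the weighted loss is exactly 3x the unweighted one
print(base.item(), weighted.item())
```

Setting pos_weight to (number of negatives / number of positives) per class is a common way to rebalance the two terms.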
