天天看點

Pytorch nn.BCEWithLogitsLoss()的簡單了解與用法

這個東西,本質上和nn.BCELoss()沒有差別,隻是在BCELoss上加了個logits函數(也就是sigmoid函數),例子如下:

import torch
import torch.nn as nn

label = torch.Tensor([1, 1, 0])
pred = torch.Tensor([3, 2, 1])
pred_sig = torch.sigmoid(pred)
loss = nn.BCELoss()
print(loss(pred_sig, label))

loss = nn.BCEWithLogitsLoss()
print(loss(pred, label))

loss = nn.BCEWithLogitsLoss()
print(loss(pred_sig, label))
           

輸出結果分别為:

tensor(0.4963)
tensor(0.4963)
tensor(0.5990)
           

可以看到,nn.BCEWithLogitsLoss()相當于是在nn.BCELoss()中預測結果pred的基礎上先做了個sigmoid,然後繼續正常算loss。是以這就涉及到一個比較奇葩的bug,如果網絡本身在輸出結果的時候已經用sigmoid去處理了,算loss的時候用nn.BCEWithLogitsLoss()…那麼就會相當于預測結果算了兩次sigmoid,可能會出現各種奇奇怪怪的問題——

比如網絡收斂不了(流淚貓貓頭.jpg)

Ref

[1] https://zhuanlan.zhihu.com/p/170558960

繼續閱讀