網絡訓練的時候,都會遇到一些常用的loss函數,很多常用的loss函數被封裝的很好,但是我在使用的時候,總是覺得像黑盒子,知道函數的大概形式,有些細節不了解,是以挑了幾個常用的loss函數進行了重寫,這樣能夠更深刻的了解。
另外,很多在loss層面上進行改進的論文,例如GIOU, Focalloss以及GHM_loss,如果基本loss都不是很了解的話,這些改進的loss的paper讀起來也很艱難。
1、F.smooth_l1_loss
import torch
import torch.nn.functional as F
def smooth_l1_loss(a, b):
    """Hand-rolled Smooth-L1 (Huber, beta=1) loss with 'sum' reduction.

    For each element d = |a - b|:
      - d <  1: 0.5 * d**2   (quadratic near zero)
      - d >= 1: d - 0.5      (linear tail, gradient bounded at +/-1)

    Matches F.smooth_l1_loss(a, b, reduction='sum').

    Args:
        a, b: tensors of the same (broadcastable) shape.
    Returns:
        0-dim tensor: the loss summed over all elements.
    """
    diff = torch.abs(a - b)
    # At diff == 1 both branches equal 0.5, so the >= boundary is seamless.
    elementwise = torch.where(diff >= 1, diff - 0.5, 0.5 * diff ** 2)
    return torch.sum(elementwise)
def test_smmoth_l1_loss():
    """Compare F.smooth_l1_loss against the hand-rolled version on a small example."""
    loc_p = torch.tensor([1, 5, 3, 0.5])
    loc_t = torch.tensor([4, 1, 0, 0.4])
    # size_average= is deprecated in PyTorch; reduction='sum' is the
    # modern equivalent of size_average=False.
    loss_1 = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
    print ("F.smooth_l1_loss:", loss_1)
    loss_2 = smooth_l1_loss(loc_p, loc_t)
    print ("smooth_l1_loss:", loss_2)
輸出結果為:
F.smooth_l1_loss: tensor(8.5050)
smooth_l1_loss: tensor(8.5050)
2、F.cross_entropy
def cross_entropy(input, target):
    """Hand-rolled softmax cross-entropy, matching F.cross_entropy (mean reduction).

    loss_i = logsumexp(input_i) - input_i[target_i], averaged over the batch.

    Uses torch.logsumexp instead of input.exp().sum().log() so large logits
    do not overflow to inf, and gathers the true-class logit directly
    instead of building a one-hot matrix (which also avoided dtype/device
    mismatches with the input tensor).

    Args:
        input:  float tensor of shape (N, C) with raw logits.
        target: long tensor of shape (N,) with class indices in [0, C).
    Returns:
        0-dim tensor: mean cross-entropy over the batch.
    """
    log_partition = torch.logsumexp(input, dim=1)
    true_logit = input.gather(1, target.unsqueeze(1)).squeeze(1)
    return torch.mean(log_partition - true_logit)
def test_cross_entropy():
    """Sanity-check the hand-rolled cross_entropy against F.cross_entropy."""
    logits = torch.tensor([[-1.4694, -2.2030, 2.4750],
                           [-1.0823, -0.5950, -1.4115]])
    labels = torch.tensor([0, 2])
    print("F.cross_entropy:", F.cross_entropy(logits, labels))
    print("cross_entropy:", cross_entropy(logits, labels))
test_cross_entropy()
F.cross_entropy: tensor(2.7550)
cross_entropy: tensor(2.7550)
3、F.binary_cross_entropy
def binary_cross_entropy(input, target):
    """Hand-rolled elementwise BCE, matching F.binary_cross_entropy (mean reduction).

    loss = -(t * log(p) + (1 - t) * log(1 - p)), averaged over all elements.

    Each log term is clamped at -100, mirroring PyTorch's own implementation,
    so probabilities of exactly 0 or 1 give a large finite loss instead of
    inf/nan gradients.

    Args:
        input:  tensor of probabilities in [0, 1] (e.g. sigmoid outputs).
        target: tensor of the same shape with values in [0, 1].
    Returns:
        0-dim tensor: mean BCE over all elements.
    """
    log_p = torch.clamp(torch.log(input), min=-100)
    log_not_p = torch.clamp(torch.log(1.0 - input), min=-100)
    bce = -(target * log_p + (1.0 - target) * log_not_p)
    return torch.mean(bce)
def test_binary_cross_entropy():
    """Compare F.binary_cross_entropy against the hand-rolled version."""
    input = torch.tensor([[1.4271, -1.8701],
                          [-1.1962, -2.0440],
                          [-0.4560, -1.4295]])
    target = torch.tensor([[1., 0.],
                           [1., 0.],
                           [0., 1.]])
    # F.sigmoid is deprecated; torch.sigmoid is the supported spelling.
    loss_1 = F.binary_cross_entropy(torch.sigmoid(input), target)
    print("F.binary_cross_entropy:", loss_1)
    loss_2 = binary_cross_entropy(torch.sigmoid(input), target)
    print("binary_cross_entropy:", loss_2)
F.binary_cross_entropy: tensor(0.6793)
binary_cross_entropy: tensor(0.6793)