
Digging into the computation details of F.smooth_l1_loss, F.cross_entropy, and F.binary_cross_entropy

  When training a network you constantly run into a handful of common loss functions. Most of them come well encapsulated, but whenever I use them they feel like black boxes: I know the rough shape of the function, yet some of the details escape me. So I picked a few common losses and reimplemented them from scratch, which makes it much easier to understand them deeply.

  Besides, many papers that innovate at the loss level, such as GIoU, Focal Loss, and GHM loss, are hard to read if you do not have a solid grasp of the basic losses they build on.

1. F.smooth_l1_loss
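With the default threshold of 1, smooth L1 is the piecewise function

    loss(x) = 0.5 * x**2    if |x| < 1
    loss(x) = |x| - 0.5     otherwise

applied elementwise to x = a - b and then reduced (summed below). The quadratic branch gives smooth gradients near zero; the linear branch keeps outliers from dominating the loss.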

import torch
import torch.nn.functional as F
# hand-rolled smooth_l1_loss
def smooth_l1_loss(a, b):
    loss_part1 = torch.abs(a - b)
    loss_part2 = loss_part1 ** 2
    loss_part2 = loss_part2 * 0.50
    loss2 = torch.where(loss_part1 >= 1, loss_part1 - 0.5, loss_part2)
    # for 2-D input (num_boxes x 4), uncomment to get a per-box loss
    #loss2 = torch.sum(loss2, dim = 1)
    # return the loss summed over all predicted boxes
    loss2 = torch.sum(loss2)
    return loss2

def test_smooth_l1_loss():
    loc_p = torch.tensor([1, 5, 3, 0.5])
    loc_t = torch.tensor([4, 1, 0, 0.4])
    # reduction='sum' replaces the deprecated size_average=False
    loss_1 = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
    print("F.smooth_l1_loss:", loss_1)
    loss_2 = smooth_l1_loss(loc_p, loc_t)
    print("smooth_l1_loss:", loss_2)

test_smooth_l1_loss()

The output is:

F.smooth_l1_loss: tensor(8.5050)

smooth_l1_loss: tensor(8.5050)
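As an aside, newer PyTorch versions also expose a beta argument on F.smooth_l1_loss that moves the switch point between the quadratic and linear branches. Here is a minimal sketch of that generalized form (the name smooth_l1_loss_beta is mine; this assumes the piecewise definition from the PyTorch docs, beta > 0):

def smooth_l1_loss_beta(a, b, beta=1.0):
    # quadratic for |x| < beta (scaled by 1/beta), linear beyond it
    diff = torch.abs(a - b)
    loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    return torch.sum(loss)

With beta=1.0 this reduces to the version above.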

2. F.cross_entropy
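For a row of logits x and target class t, PyTorch's cross entropy is the negative log-softmax of the target class, which splits into exactly the two parts computed below:

    loss = -log( exp(x_t) / sum_j exp(x_j) )
         = log( sum_j exp(x_j) ) - x_t

averaged over the batch.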

def cross_entropy(input, target):
    # log of the softmax normalizer: log(sum_j exp(x_j)) per row
    loss_part1 = torch.sum(input.exp(), dim = 1).log()
    # build a one-hot encoding of target; scatter_(dim, index, value)
    y_onehot = torch.zeros(input.shape, dtype=input.dtype)
    y_onehot.scatter_(1, target.unsqueeze(1), 1)
    # pick out the logit of the target class and negate it
    loss_part2 = torch.sum(input * y_onehot, dim = 1) * -1
    loss = loss_part1 + loss_part2
    return torch.mean(loss)

def test_cross_entropy():
    input = torch.tensor([[-1.4694, -2.2030, 2.4750],
                         [-1.0823, -0.5950, -1.4115]])
    target = torch.tensor([0, 2])
    loss_1 = F.cross_entropy(input, target)
    print("F.cross_entropy:", loss_1)
    loss_2 = cross_entropy(input, target)
    print("cross_entropy:", loss_2)

test_cross_entropy()

F.cross_entropy: tensor(2.7550)

cross_entropy: tensor(2.7550)
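One caveat: input.exp() overflows once logits get large (roughly above 88 for float32). Below is a minimal sketch of a numerically stable variant using torch.logsumexp and gather (the name cross_entropy_stable is mine):

def cross_entropy_stable(input, target):
    # logsumexp factors out the max logit internally, so exp() cannot overflow
    log_z = torch.logsumexp(input, dim=1)
    # gather picks the logit at the target index in each row
    x_t = input.gather(1, target.unsqueeze(1)).squeeze(1)
    return torch.mean(log_z - x_t)

On the inputs above this prints the same tensor(2.7550).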

3. F.binary_cross_entropy

def binary_cross_entropy(input, target):
    # elementwise BCE, expects input already passed through a sigmoid
    bce = -(target * torch.log(input) + (1.0 - target) * torch.log(1.0 - input))
    return torch.mean(bce)

def test_binary_cross_entropy():
    input = torch.tensor([[1.4271, -1.8701],
                          [-1.1962, -2.0440],
                          [-0.4560, -1.4295]])
    target = torch.tensor([[1., 0.],
                           [1., 0.],
                           [0., 1.]])

    # torch.sigmoid replaces the deprecated F.sigmoid
    loss_1 = F.binary_cross_entropy(torch.sigmoid(input), target)
    print("F.binary_cross_entropy:", loss_1)
    loss_2 = binary_cross_entropy(torch.sigmoid(input), target)
    print("binary_cross_entropy:", loss_2)

test_binary_cross_entropy()

F.binary_cross_entropy: tensor(0.6793)

binary_cross_entropy: tensor(0.6793)
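For the same numerical-stability reason as with cross entropy, PyTorch also offers F.binary_cross_entropy_with_logits, which fuses the sigmoid into the loss via the log-sum-exp trick and is the safer choice when you still have raw logits. A quick check, reusing the same input and target as above:

def test_bce_with_logits():
    input = torch.tensor([[1.4271, -1.8701],
                          [-1.1962, -2.0440],
                          [-0.4560, -1.4295]])
    target = torch.tensor([[1., 0.],
                           [1., 0.],
                           [0., 1.]])
    # takes raw logits directly, no explicit sigmoid needed
    loss = F.binary_cross_entropy_with_logits(input, target)
    print("F.binary_cross_entropy_with_logits:", loss)

test_bce_with_logits()

This prints the same tensor(0.6793).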
