天天看點

pytorch自定義交叉熵損失函數

這個在label為1維的時候能對的上。

2維測試交叉熵代碼:

注意:

output 次元是[batch_size,所分類預測值,樣本數]

label次元是[batch_size,樣本數]

    output = torch.randn(3, 3,5, requires_grad=True)

    label = torch.empty((3,5), dtype=torch.long).random_(3)

import torch
import torch.nn as nn
import numpy as np


class CrossEntropyLoss(nn.Module):
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()

    def forward(self, output, label):

        if label.dim()>1:
            output=output.permute(0,2,1)
            label=label.view(-1)
            output=output.reshape((label.size(0),output.size(2)))
        first = [-output[i][label[i]] for i in range(label.size()[0])]
        first_ = 0
        for i in range(len(first)):
            first_ += first[i]

        second = torch.exp(output)
        second = torch.sum(second, dim=1)