如何计算Pytorch中二进制分类的交叉熵损失?

时间:2017-08-25 14:49:03

标签: pytorch

对于二进制分类,我的输出和标签是这样的

output = [0.7, 0.3, 0.1, 0.9 ... ]
label = [1, 0, 0, 1 ... ]

其中输出是precited label = 1

的概率

我想要一个像这样的交叉熵:

def cross_entropy(output, label):
    return sum(-label * log(output) - (1 - label) * log(1 - output))

但是,这会给我一个NaN错误,因为log(output)中的output可能为零。

我知道torch.nn.CrossEntropyLoss但它不适用于我的数据格式。

2 个答案:

答案 0 :(得分:1)

import torch
import torch.nn.functional as F
def my_binary_cross_entrophy(output,label):
    label = label.float()
    #print(label)
    loss = 0
    for i in range(len(label)):
        loss += -(label[i]*math.log(output[i])+(1-label[i])*math.log(1-output[i]))
        #print(loss)
    return loss/len(label)

label1 = torch.randint(0,2,(3,)).float()
output = torch.rand(3)
my_binary_cross_entrophy(output,label1)

它返回的值与F.binary_cross_entropy值相同。

F.binary_cross_entropy(output,label1)

答案 1 :(得分:0)

Leonard2在对问题的评论中提到,torch.nn.BCELoss(意思是“二进制交叉熵损失”似乎正是所要的。