对于二进制分类,我的输出和标签是这样的
output = [0.7, 0.3, 0.1, 0.9 ... ]
label = [1, 0, 0, 1 ... ]
其中输出是precited label = 1
我想要一个像这样的交叉熵:
def cross_entropy(output, label):
return sum(-label * log(output) - (1 - label) * log(1 - output))
但是,这会给我一个NaN错误,因为log(output)
中的output
可能为零。
我知道torch.nn.CrossEntropyLoss
但它不适用于我的数据格式。
答案 0 :(得分:1)
import torch
import torch.nn.functional as F
def my_binary_cross_entrophy(output,label):
label = label.float()
#print(label)
loss = 0
for i in range(len(label)):
loss += -(label[i]*math.log(output[i])+(1-label[i])*math.log(1-output[i]))
#print(loss)
return loss/len(label)
label1 = torch.randint(0,2,(3,)).float()
output = torch.rand(3)
my_binary_cross_entrophy(output,label1)
它返回的值与F.binary_cross_entropy值相同。
F.binary_cross_entropy(output,label1)
答案 1 :(得分:0)
Leonard2在对问题的评论中提到,torch.nn.BCELoss
(意思是“二进制交叉熵损失”似乎正是所要的。