如何使用Pytorch中的交叉熵损失进行二进制预测

时间:2018-08-18 00:44:15

标签: pytorch

在pytorch文档中,它表示交叉熵损失:

  

输入必须为张量大小(minibatch,C)

这是否意味着对于二进制(0,1)预测,必须将输入转换为第二维等于(1-p)的(N,2)张量?

例如,如果我为目标为1(真)的类预测0.75,我是否必须将两个值(0.75; 0.25)叠加在一起作为输入?

1 个答案:

答案 0 :(得分:0)

快速简便:是的,仅将1.0作为真实类,将0.0作为其他类作为目标值。您的模型还应该针对这种情况生成两个预测,尽管可以仅使用一个预测就可以做到这一点,并使用符号信息来确定类别。在这种情况下,使用softmax作为最后一个操作将不会有概率,而是例如使用Sigmoid函数(它将您的输出从(-inf,inf)映射到(0,1)。