Pytorch分类交叉熵损失函数行为

时间:2019-11-18 21:48:36

标签: pytorch

我对来自Pytorch的分类交叉熵损失所做的计算有疑问。 我已经制作了这个简单的代码段,并且因为我将输出张量的argmax用作目标,所以我不明白为什么损失仍然很高。

import torch
import torch.nn as nn
ce_loss = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
targets = torch.argmax(output, dim=1)
loss = ce_loss(outputs, targets)
print(loss)

感谢您的理解。 最好的祝福 杰罗姆(Jerome)

1 个答案:

答案 0 :(得分:1)

这是代码中的示例数据,其中outputlabelloss具有以下值

outputs =  tensor([[ 0.5968, -0.8249,  1.5018,  2.7888, -0.6125],
                   [-1.1534, -0.4921,  1.0688,  0.2241, -0.0257],
                   [ 0.3747,  0.8957,  0.0816,  0.0745,  0.2695]], requires_grad=True)requires_grad=True)

labels = tensor([3, 2, 1])
loss = tensor(0.7354, grad_fn=<NllLossBackward>)

所以让我们检查一下值

如果您计算登录数(outputs的softmax输出,则使用类似torch.softmax(outputs,axis=1)的东西,您将得到

probs = tensor([[0.0771, 0.0186, 0.1907, 0.6906, 0.0230],
                [0.0520, 0.1008, 0.4801, 0.2063, 0.1607],
                [0.1972, 0.3321, 0.1471, 0.1461, 0.1775]], grad_fn=<SoftmaxBackward>)

因此,这些将是您的预测概率。

现在交叉熵损失仅是softmaxnegative log likelihood loss.的组合,因此,您的损失可以简单地通过使用

进行计算
loss = (torch.log(1/probs[0,3]) +  torch.log(1/probs[1,2]) + torch.log(1/probs[2,1])) / 3

,它是真实标签的概率的负对数的平均值。上面的等式求值为0.7354,它等于从nn.CrossEntropyLoss模块返回的值。