我对来自Pytorch的分类交叉熵损失所做的计算有疑问。 我已经制作了这个简单的代码段,并且因为我将输出张量的argmax用作目标,所以我不明白为什么损失仍然很高。
import torch
import torch.nn as nn
ce_loss = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
targets = torch.argmax(output, dim=1)
loss = ce_loss(outputs, targets)
print(loss)
感谢您的理解。 最好的祝福 杰罗姆(Jerome)
答案 0 :(得分:1)
这是代码中的示例数据,其中output
,label
和loss
具有以下值
outputs = tensor([[ 0.5968, -0.8249, 1.5018, 2.7888, -0.6125],
[-1.1534, -0.4921, 1.0688, 0.2241, -0.0257],
[ 0.3747, 0.8957, 0.0816, 0.0745, 0.2695]], requires_grad=True)requires_grad=True)
labels = tensor([3, 2, 1])
loss = tensor(0.7354, grad_fn=<NllLossBackward>)
所以让我们检查一下值
如果您计算登录数(outputs
的softmax输出,则使用类似torch.softmax(outputs,axis=1)
的东西,您将得到
probs = tensor([[0.0771, 0.0186, 0.1907, 0.6906, 0.0230],
[0.0520, 0.1008, 0.4801, 0.2063, 0.1607],
[0.1972, 0.3321, 0.1471, 0.1461, 0.1775]], grad_fn=<SoftmaxBackward>)
因此,这些将是您的预测概率。
现在交叉熵损失仅是softmax
和negative log likelihood loss.
的组合,因此,您的损失可以简单地通过使用
loss = (torch.log(1/probs[0,3]) + torch.log(1/probs[1,2]) + torch.log(1/probs[2,1])) / 3
,它是真实标签的概率的负对数的平均值。上面的等式求值为0.7354
,它等于从nn.CrossEntropyLoss
模块返回的值。