PyTorch中的交叉熵损失的输入维

时间:2020-04-29 12:04:53

标签: pytorch loss index-error

对于batch_size = 1的二进制分类问题,我需要使用logit和label值来计算损失。

logit: tensor([0.1198, 0.1911], device='cuda:0', grad_fn=<AddBackward0>)
label: tensor(1], device='cuda:0')
# calculate loss
loss_criterion = nn.CrossEntropyLoss()
loss_criterion.cuda()
loss = loss_criterion( b_logits, b_labels )

但是,这总是会导致以下错误,

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

CrossEntropyLoss实际要求什么输入尺寸?

1 个答案:

答案 0 :(得分:1)

您传递的张量形状错误。
shape应该是(from doc

  • 输入: (N,C) 其中C =类数
  • 目标:(N),其中每个值为0 ≤ targets[i] ≤ C−1

因此,在这里,b_logits的形状应为([1,2])而不是([2]),以使其正确显示形状,您可以像b_logits.view(1,-1)一样使用torch.view

并且b_labels的形状应为([1])
例如:

b_logits = torch.tensor([0.1198, 0.1911], requires_grad=True)
b_labels = torch.tensor([1])
loss_criterion = nn.CrossEntropyLoss()

loss = loss_criterion( b_logits.view(1,-1), b_labels )
loss
tensor(0.6581, grad_fn=<NllLossBackward>)