3D输入的Pytorch交叉熵损失

时间:2020-08-29 15:31:50

标签: python neural-network pytorch cross-entropy

我有一个网络,该网络输出大小为(batch_size, max_len, num_classes)的3D张量。我的真实真理是(batch_size, max_len)的形式。如果我确实对标签执行一次热编码,则其形状将为(batch_size, max_len, num_classes),即max_len中的值是[0, num_classes]范围内的整数。由于原始代码太长,我编写了一个更简单的版本来再现原始错误。

criterion = nn.CrossEntropyLoss()
batch_size = 32
max_len = 350
num_classes = 1000
pred = torch.randn([batch_size, max_len, num_classes])
label = torch.randint(0, num_classes,[batch_size, max_len])
pred = nn.Softmax(dim = 2)(pred)
criterion(pred, label)

pred和label的形状分别为torch.Size([32, 350, 1000])torch.Size([32, 350])

遇到的错误是

ValueError:预期的目标大小(32,1000),得到了torch.Size([32,350,1000])

如果我对标签进行一次热编码以计算损失

x = nn.functional.one_hot(label)
criterion(pred, x)

它将引发以下错误

ValueError:预期的目标大小(32,1000),得到了torch.Size([32,350,1000])

1 个答案:

答案 0 :(得分:1)

Pytorch documentation中,CrossEntropyLoss期望其输入的形状为(N, C, ...),因此第二维始终是类的数量。如果将preds的大小调整为(batch_size, num_classes, max_len),则代码应该可以正常工作。