我有一个网络,该网络输出大小为(batch_size, max_len, num_classes)
的3D张量。我的真实真理是(batch_size, max_len)
的形式。如果我确实对标签执行一次热编码,则其形状将为(batch_size, max_len, num_classes)
,即max_len
中的值是[0, num_classes]
范围内的整数。由于原始代码太长,我编写了一个更简单的版本来再现原始错误。
criterion = nn.CrossEntropyLoss()
batch_size = 32
max_len = 350
num_classes = 1000
pred = torch.randn([batch_size, max_len, num_classes])
label = torch.randint(0, num_classes,[batch_size, max_len])
pred = nn.Softmax(dim = 2)(pred)
criterion(pred, label)
pred和label的形状分别为torch.Size([32, 350, 1000])
和torch.Size([32, 350])
遇到的错误是
ValueError:预期的目标大小(32,1000),得到了torch.Size([32,350,1000])
如果我对标签进行一次热编码以计算损失
x = nn.functional.one_hot(label)
criterion(pred, x)
它将引发以下错误
ValueError:预期的目标大小(32,1000),得到了torch.Size([32,350,1000])
答案 0 :(得分:1)
在Pytorch documentation中,CrossEntropyLoss
期望其输入的形状为(N, C, ...)
,因此第二维始终是类的数量。如果将preds
的大小调整为(batch_size, num_classes, max_len)
,则代码应该可以正常工作。