Pytorch预期类型为Long,但类型为int

时间:2019-05-31 08:12:33

标签: python python-3.x pytorch tensor

我恢复了错误

 Expected object of scalar type Long but got scalar type Int for argument #3 'index'

这是从此行。

targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)

我不确定该怎么做,因为我尝试使用多个位置将其转换为很长的时间。我试图放一个

.long

最后,将dtype设置为torch.long仍然无法正常工作。

与此非常相似,但是他没有做任何事情来得到答案 "Expected Long but got Int" while running PyTorch script

2 个答案:

答案 0 :(得分:0)

您的索引参数的dtype(即targets.unsqueeze(1).data.cpu())必须为torch.int64

(错误消息有点令人困惑:torch.long不存在。但是PyTorch内部的“长”表示int64)。

答案 1 :(得分:0)

targets = torch.zeros(log_probs.size()).scatter_(1, (targets.unsqueeze(1).data.cpu()).long(), 1)