我恢复了错误
Expected object of scalar type Long but got scalar type Int for argument #3 'index'
这是从此行。
targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
我不确定该怎么做,因为我尝试使用多个位置将其转换为很长的时间。我试图放一个
.long
最后,将dtype设置为torch.long仍然无法正常工作。
与此非常相似,但是他没有做任何事情来得到答案 "Expected Long but got Int" while running PyTorch script
答案 0 :(得分:0)
您的索引参数的dtype(即targets.unsqueeze(1).data.cpu()
)必须为torch.int64
。
(错误消息有点令人困惑:torch.long
不存在。但是PyTorch内部的“长”表示int64)。
答案 1 :(得分:0)
targets = torch.zeros(log_probs.size()).scatter_(1, (targets.unsqueeze(1).data.cpu()).long(), 1)