怎么把tort int64转换成LongTensor?

时间:2019-06-08 21:15:11

标签: python pytorch

我正在学习一门使用不推荐使用的PyTorch版本的课程,该版本不会根据需要将torch.int64更改为torch.LongTensor。引发错误的当前代码部分是:

loss = loss_fn(Ypred, Ytrain_) # calc loss on the prediction

我相信dtype应该在此部分中进行更改:

Ytrain_ = torch.from_numpy(y_train.values).view(1, -1)[0]

使用Ytrain_.dtype测试数据类型时,它将返回torch.int64。我试图通过应用long()函数将其转换为:Ytrain_ = Ytrain_.long()无济于事。

我也尝试过在documentation中寻找它,但是似乎它说torch.int64torch.long,我认为这意味着torch.int64应该可以工作。

RuntimeError                              Traceback (most recent call last)
----> 9     loss = loss_fn(Ypred, Ytrain_) # calc loss on the prediction
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'target'

1 个答案:

答案 0 :(得分:0)

user8426627所述,您想更改张量类型,而不是数据类型。因此,解决方案是添加.type(torch.LongTensor)以将其转换为LongTensor

最终代码:

Ytrain_ = torch.from_numpy(Y_train.values).view(1, -1)[0].type(torch.LongTensor)

测试张量类型:

Ytrain_.type()

'torch.LongTensor'