如何在LibTorch(C ++)中将Torch :: Torch的类型从Float更改为Long

时间:2019-12-08 16:57:26

标签: c++ libtorch

我正在使用LibTorch(PyTorch C ++ API)在C ++中进行编码。 在这里,我传递了均为大小{1,1}的Torch :: Tensor的预测值和目标值。

torch::Tensor loss = torch::nll_loss(predicted_value, target_value);

当我尝试评估以上内容时,出现以下错误:

 0.4997 [ Variable[CPUFloatType]{1,1} ]   # printout of predicted_value
-0.5392 [ Variable[CPUFloatType]{1,1} ]   # printout of target_value
terminate called after throwing an instance of 'c10::Error'
  what():  Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward (checked_dense_tensor_unwrap at ../../aten/src/ATen/Utils.h:84)

我尝试搜索如何将浮点型张量转换为长型张量,但只能找到Python的文档。非常感谢解决此问题的建议!

2 个答案:

答案 0 :(得分:0)

tensor.to(torch::kLong)为您提供Long类型。

这是Tensor的{​​{1}}函数的重载定义:

to

答案 1 :(得分:-1)

inline Tensor Tensor::to(ScalarType dtype, bool non_blocking, bool copy) const {
    static auto table = globalATenDispatch().getOpTable("aten::to(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False) -> Tensor");
    return table->getOp<Tensor (const Tensor &, ScalarType, bool, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dtype, non_blocking, copy);
}