如何在PyTorch中将double用作浮点数的默认类型

时间:2018-09-06 08:27:53

标签: python image-processing machine-learning computer-vision pytorch

默认情况下,我希望PyTorch代码中的所有浮点数都为双精度类型,怎么办?

2 个答案:

答案 0 :(得分:3)

您正在寻找torch.set_default_tensor_type

torch.set_default_tensor_type(torch.DoubleTensor)

或者,您可以使用torch.set_default_dtype

torch.set_default_dtype(torch.float64)

答案 1 :(得分:0)

您应该为此torch.set_default_dtype使用。

确实,使用torch.set_default_tensor_type也会产生类似的效果,但是torch.set_default_tensor_type不仅设置默认数据类型,而且还设置设备的默认值张量的分配位置以及张量的布局