默认情况下,我希望PyTorch代码中的所有浮点数都为双精度类型,怎么办?
答案 0 :(得分:3)
您正在寻找torch.set_default_tensor_type
:
torch.set_default_tensor_type(torch.DoubleTensor)
或者,您可以使用torch.set_default_dtype
:
torch.set_default_dtype(torch.float64)
答案 1 :(得分:0)
您应该为此torch.set_default_dtype
使用。
确实,使用torch.set_default_tensor_type
也会产生类似的效果,但是torch.set_default_tensor_type
不仅设置默认数据类型,而且还设置设备的默认值张量的分配位置以及张量的布局。