PyTorch何时自动转换Tensor dtype?为什么有时它会自动执行,而有时却抛出错误?
例如,这会自动将c
强制转换为浮点数:
a = torch.tensor(5)
b = torch.tensor(5.)
c = a*b
a.dtype
>>> torch.int64
b.dtype
>>> torch.float32
c.dtype
>>> torch.float32
但这会引发错误:
a = torch.ones(2, dtype=torch.float)
b = torch.ones(2, dtype=torch.long)
c = torch.matmul(a,b)
Traceback (most recent call last):
File "<ipython-input-128-fbff7a713ff0>", line 1, in <module>
torch.matmul(a,b)
RuntimeError: Expected object of scalar type Float but got scalar type Long for argument #2 'tensor'
我很困惑,因为Numpy似乎会根据需要自动转换所有数组,例如
a = np.ones(2, dtype=np.long)
b = np.ones(2, dtype=np.float)
np.matmul(a,b)
>>> 2.0
a*b
>>> array([1., 1.])
答案 0 :(得分:3)
PyTorch团队似乎正在解决这些类型的问题,请参见this issue。按照您的示例,似乎已经在1.0.0中实现了一些基本的向上转换(可能是对于重载的运算符,尝试了诸如“ //”或加法运算的其他运算符,但它们工作正常),但是没有找到任何证明(例如github问题或文档中的信息)。如果有人找到了它(对各种操作隐式转换为torch.Tensor
,请发表评论或其他答案。
This issue是有关类型提升的建议,因为您可以看到所有这些内容仍处于打开状态。