PyTorch什么时候自动转换Tensor dtype?

时间:2019-01-11 23:18:45

标签: python numpy pytorch

PyTorch何时自动转换Tensor dtype?为什么有时它会自动执行,而有时却抛出错误?

例如,这会自动将c强制转换为浮点数:

a = torch.tensor(5)    
b = torch.tensor(5.)
c = a*b 

a.dtype
>>> torch.int64

b.dtype
>>> torch.float32

c.dtype
>>> torch.float32

但这会引发错误:

a = torch.ones(2, dtype=torch.float)   
b = torch.ones(2, dtype=torch.long)    
c = torch.matmul(a,b)

Traceback (most recent call last):

  File "<ipython-input-128-fbff7a713ff0>", line 1, in <module>
    torch.matmul(a,b)

RuntimeError: Expected object of scalar type Float but got scalar type Long for argument #2 'tensor'

我很困惑,因为Numpy似乎会根据需要自动转换所有数组,例如

a = np.ones(2, dtype=np.long)
b = np.ones(2, dtype=np.float)

np.matmul(a,b)
>>> 2.0

a*b
>>> array([1., 1.])

1 个答案:

答案 0 :(得分:3)

PyTorch团队似乎正在解决这些类型的问题,请参见this issue。按照您的示例,似乎已经在1.0.0中实现了一些基本的向上转换(可能是对于重载的运算符,尝试了诸如“ //”或加法运算的其他运算符,但它们工作正常),但是没有找到任何证明(例如github问题或文档中的信息)。如果有人找到了它(对各种操作隐式转换为torch.Tensor,请发表评论或其他答案。

This issue是有关类型提升的建议,因为您可以看到所有这些内容仍处于打开状态。