在PyTorch中设置浮点类型时,张量类型和dtype有什么区别,何时应该在另一个之上设置一个?

时间:2019-04-04 17:40:57

标签: numpy runtime-error pytorch

我正在使用double作为模型的输入和输出,因此我试图将手电筒设置为使用float64而不是float32。到底有什么区别

torch.set_default_tensor_type(torch.DoubleTensor)

  

设置默认的割炬。张量类型为浮点张量类型t。此类型也将用作torch.tensor()中类型推断的默认浮点类型。

torch.set_default_dtype(torch.float64)

  

将默认浮点dtype设置为d。此类型将用作torch.tensor()中类型推断的默认浮点类型。

文档告诉我,设置张量类型也会设置dtype,但是我不确定何时将一个使用另一个。

我应该提到,这两个语句都可以解决从浮点数变为双精度后出现的错误:

Traceback (most recent call last):
  File "train.py", line 122, in train_model  
    output = net(action)  
  File "/opt/anaconda3/lib/python3.7/site- packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "models.py", line 25, in forward
    return self.fc2(F.relu(self.fc1(x)))
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 67, in forward
    return F.linear(input, self.weight, self.bias)
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 1352, in linear
    ret = torch.addmm(torch.jit._unwrap_optional(bias), input, weight.t())
RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #4 'mat1'

0 个答案:

没有答案