为什么我的简单pytorch网络不能在GPU设备上工作?

时间:2018-07-31 05:18:23

标签: python image-processing machine-learning deep-learning pytorch

我从教程中构建了一个简单的网络,但出现此错误:

  

RuntimeError:类型为torch.cuda.FloatTensor的预期对象,但找到了   为参数#4'mat1'输入torch.FloatTensor

有帮助吗?谢谢!

12.2.1.2.0

2 个答案:

答案 0 :(得分:5)

TL; DR
这就是解决方法

inputs = inputs.to(device)  

为什么?!
torch.nn.Module.to()torch.Tensor.to()之间存在细微差别:Module.to()就地运算符,而Tensor.to()不是。因此

net.to(device)

更改net本身并将其移动到device。另一方面

inputs.to(device)

不会更改inputs,而是返回位于inputs上的device副本。要使用该“在设备上”副本,您需要将其分配给一个变量,因此

inputs = inputs.to(device)

答案 1 :(得分:0)

import torch
import numpy as np

x = torch.tensor(np.array(1), device='cuda:0')

print(x.device)  # Prints `cpu`

x = torch.tensor(1, device='cuda:0')

print(x.device)  # Prints `cuda:0`

现在,张量位于GPU上