使pytorch代码与在CPU或GPU上运行无关的更好方法?

时间:2018-10-02 17:24:58

标签: python gpu cpu pytorch

迁移guide建议采取以下措施,使代码与CPU / GPU无关:

{ 
 "requestTime": 1, 
 "clients": [{
   "id": 905, 
   "name": "Peter" 
  }] 
 }

我这样做了,并在仅CPU的设备上运行了我的代码,但是当输入一个输入数组时,我的模型崩溃了,原因是它期望的是CPU张量而不是GPU张量。我的模型以某种方式自动将CPU输入阵列转换为GPU阵列。最后,我在代码中将其追溯到该命令:

> # at beginning of the script
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
...
# then whenever you get a new Tensor or Module
# this won't copy if they are already on the desired device
input = data.to(device)
model = MyModule(...).to(device)

即使我将模型转换为“ cpu”,但nn.DataParallel仍将其覆盖。我想出的最佳解决方案是有条件的:

model = torch.nn.DataParallel(model).to(device)

这看起来并不优雅。有更好的方法吗?

1 个答案:

答案 0 :(得分:0)

怎么样

if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
model = model.to(device)

如果只有一个GPU,则不需要DataParallel