关于torch.nn.DataParallel的问题

时间:2018-10-05 10:16:20

标签: python pytorch

我是深度学习领域的新手。现在,我正在复制论文的代码。由于它们使用多个GPU,因此代码中有一个命令torch.nn.DataParallel(model, device_ids= args.gpus).cuda()。但是我只有一个GPU 我应该更改此代码以匹配我的GPU吗?

谢谢!

1 个答案:

答案 0 :(得分:3)

DataParallel也应该在单个GPU上工作,但是您应该检查args.gpus是否仅包含要使用的设备的ID(应为0)或None 。 选择None将使该模块使用所有可用的设备。

此外,您可以删除DataParallel,因为您不需要它,而只需通过调用model.cuda()或(如我更喜欢的model.to(device)即可将模型移至GPU,其中device是设备的名称。

示例:

此示例说明如何在单个GPU上使用模型,并使用.to()而非.cuda()设置设备。

from torch import nn
import torch

# Set device to cuda if cuda is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create model
model = nn.Sequential(
  nn.Conv2d(1,20,5),
  nn.ReLU(),
  nn.Conv2d(20,64,5),
  nn.ReLU()
)

# moving model to GPU
model.to(device)

如果您想使用DataParallel,可以这样做

# Optional DataParallel, not needed for single GPU usage
model1 = torch.nn.DataParallel(model, device_ids=[0]).to(device)
# Or, using default 'device_ids=None'
model1 = torch.nn.DataParallel(model).to(device)