CUDA与DataParallel:为什么不同?

时间:2017-06-16 03:45:01

标签: pytorch

我有一个简单的神经网络模型,我在模型上应用cuda()DataParallel(),如下所示。

model = torch.nn.DataParallel(model).cuda()

OR,

model = model.cuda()

当我不使用DataParallel,而只是简单地将我的模型转换为cuda()时,我需要将批输入显式转换为cuda(),然后将其提供给模型,否则返回以下错误。

  

torch.index_select收到了无效的参数组合 - got(torch.cuda.FloatTensor,int,torch.LongTensor)

但是使用DataParallel,代码运行正常。其他的事情是一样的。为什么会这样?为什么当我使用DataParallel时,我不需要明确地将批输入转换为cuda()

1 个答案:

答案 0 :(得分:7)

因为,DataParallel允许CPU输入,因为它的第一步是将输入传输到适当的GPU。

信息来源:https://discuss.pytorch.org/t/cuda-vs-dataparallel-why-the-difference/4062/3