当网络需要共享的(常量)Tensor时使用`DataParallel`

时间:2020-01-14 10:39:44

标签: gpu pytorch

我想使用DataParallel在批处理维度上跨多个GPU分布我的计算。我的网络内部需要一个Tensor(我们称其为A),它是恒定的,并且在优化过程中不会改变。似乎DataParallel不会自动将此Tensor复制到所有有问题的GPU,因此网络将抱怨它看到的输入数据x的块位于不同的位置GPU超过A

DataParallel是否可以自动处理这种情况?或者,是否可以将Tensor复制到所有 GPU?还是应该为每个GPU保留一个Tensor并根据forward所看到的块所在的位置手动找出要使用的副本?

1 个答案:

答案 0 :(得分:2)

您应该将张量包装在torch.nn.Parameter中,并在创建过程中设置requires_grad=False

torch.nn.Parameter 并不意味着张量必须是可训练的

这仅表示它是模型的一部分,应在需要时进行转移(例如,多个GPU)。

如果不是这种情况,torch将无法知道__init__内部的哪个张量是模型的一部分(您可以对张量进行一些操作并将其添加到{{1} }只是为了完成某项工作。)

我认为不需要其他功能来做到这一点,尽管名称可能会造成一些混乱。