pytorch的并行方法和分布式方法如何工作?

时间:2018-11-19 13:13:17

标签: c++ python-3.x parallel-processing distributed-computing pytorch

我不是分布式系统和CUDA方面的专家。但是PyTorch支持的一个非常有趣的功能是nn.DataParallelnn.DistributedDataParallel。它们如何实际实施?他们如何分隔常见的嵌入并同步数据?

这是DataParallel的基本示例。

import torch.nn as nn
from torch.autograd.variable import Variable
import numpy as np

class Model(nn.Module):
    def __init__(self):
        super().__init__(
            embedding=nn.Embedding(1000, 10),
            rnn=nn.Linear(10, 10),
        )

    def forward(self, x):
        x = self.embedding(x)
        x = self.rnn(x)
        return x

model = nn.DataParallel(Model())
model.forward(Variable.from_numpy(np.array([1,2,3,4,5,6], dtype=np.int64)).cuda()).cpu()

PyTorch可以拆分输入并将其发送到许多GPU,然后将结果合并回去。

它如何管理并行模型或分布式模型的嵌入和同步?
我徘徊在PyTorch的代码上,但是很难知道基本原理如何工作。

2 个答案:

答案 0 :(得分:1)

这是一个很好的问题。
PyTorch DataParallel 范式实际上非常简单,并且实现是开源的 here 。请注意,今天不推荐他的范式,因为它在主 GPU 上存在瓶颈,并且数据传输效率不高。

<块引用>

此容器通过以下方式并行化给定 :attr:module 的应用程序 通过在批处理中分块将输入拆分到指定的设备 维度(其他对象将在每个设备上复制一次)。在前锋 通过,模块被复制到每个设备上,每个副本处理一个 输入的一部分。在向后传递期间,每个副本的梯度 合并到原始模块中。

从 DistributedDataParallel 开始,那就更棘手了。这是目前更高级的方法,而且非常有效(请参阅 here)。

<块引用>

此容器通过以下方式并行化给定模块的应用程序 通过在批处理中分块将输入拆分到指定的设备 尺寸。该模块在每台机器和每台设备上复制,并且 每个这样的副本处理输入的一部分。倒退期间 通过,每个节点的梯度被平均。

对于如何平均每个节点的梯度有几种方法。我会推荐 this 论文来真正了解事情是如何运作的。一般来说,在将数据从一个 GPU 传输到另一个 GPU 之间需要权衡带宽和速度,我们希望这部分真正高效。因此,一种可能的方法是将每对 GPU 与一个非常快速的协议连接成一个圆圈,并且仅将部分梯度从一个传递到另一个,s.t.总的来说,我们传输的数据更少,效率更高,并且所有节点都获得了所有梯度(或至少是它们的平均值)。在那种情况下仍然会有一个主 GPU,或者至少是一个进程,但现在任何 GPU 上都没有瓶颈,它们都共享相同数量的数据(最多......)。

现在可以进一步优化,如果我们不等待所有批次完成计算并开始做一个分时的事情,每个节点在他准备好时发送他的部分。细节不用我说,但事实证明,如果我们不等一切都结束,尽快做平均,也可能会加快梯度平均。

有关该领域的更多信息,请参阅文献,因为它仍在发展中(截至今天)。

PS 1:通常这些分布式训练在为该任务设置的机器上效果更好,例如在硬件中实现这些协议的 AWS 深度学习实例。

PS 2:免责声明:我真的不知道 PyTorch 开发人员选择实施什么协议,以及根据什么选择了什么。我使用分布式培训,更喜欢遵循 ​​PyTorch 最佳实践,而不是试图超越它们。我建议你也这样做,除非你真的很想研究这个领域。

参考文献:

[1] Distributed Training of Deep Learning Models: A Taxonomic Perspective

答案 1 :(得分:-2)

据我所知,该代码是在parallel_apply.py中实现的

[编辑:在此处粘贴代码以方便参考]

def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.
    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices
    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
  • modules是要并行化的模块。
  • inputs是模块的张量
  • devices是CUDA设备
  • resultsoutput存储最终结果
  • _worker()是线程运行的主要功能