Pytorch:如何在多个GPU之间平衡反向传播计算

时间:2019-08-21 04:17:05

标签: pytorch

为了平衡向后计算损失,我不仅要在forward()中计算损失,而且要向后实现损失。这是我的实现方式,有什么不对吗?

 def forward(self, x, y=None, criterion=None, gpu_nums=1):
    x = self.task(x)
    if not y == None:
        loss = criterion(x, y)
        if gpu_nums > 1:
            loss /= gpu_nums
        loss.backward()
        return x, loss
    return x

我的实现是否正确?如果不正确,希望您能提供正确的方法。这对每个人都有意义。

0 个答案:

没有答案