使用Multi-gpu时成员变量的Pytorch错误值

时间:2018-10-30 09:44:55

标签: variables pytorch multi-gpu

这是在多GPU环境中运行的简单类。成员变量self.firstIter在第一次迭代后应为False

Class TestNetwork(nn.Module):

    def __init__(self):
        super(TestNetwork, self).__init__()
        self.firstIter = True #indicates whether it's the first iteration

    def forward(self, input):
        print 'is firstIter: ', self.firstIter #always True!!
        if self.firstIter is True:
            self.firstIter = False
        # do otherthings

仅使用一个GPU时,代码将按预期工作。

但是,在使用multi-gpu(即nn.DataParallel)时,self.firstIter的值始终打印为True

为什么会这样?代码有什么问题?

使用PyTorch版本0.3.1。

1 个答案:

答案 0 :(得分:-2)

基本上,DataParallel在模型副本上运行,并且如果didiv的数量大于1,则在向前/向后调用之外看不到对副本所做的更改(前进期间)。

请详细参考https://discuss.pytorch.org/t/nonetype-attribute-when-using-dataparallel/11566