DataParallel使用for循环会发生意外行为

时间:2019-06-02 01:56:34

标签: pytorch

我在自定义的PyTorch K中有ModuleList个由nn.Module包装的自动编码器。当我尝试使用DataParallel在多个GPU上并行化模型时,我用来训练for-loop中每个自动编码器的forward主体执行了几次,这会导致严重的问题,因为我将附加到循环内的列表中,因此,具有比我期望的更多的元素。存根片段可以是:

class CustomModel(nn.Module):

    def __init__(self,**kwargs):
        super(CustomModel, self).__init__()
        self.autoencoders = nn.ModuleList([AutoEncoder(**kwargs) for i in range(K)])

        self.encoders_output = []
        self.decoders_output = []

    def forward(self, input, input_length):
        for i, autoencoder in enumerate(self.autoencoders):
            decoder_output, encoder_output = autoencoder(input, input_length)
            # debug(i)
            self.encoders_output.append(encoder_output)
            self.decoders_output.append(decoder_output)
        self.encoders_output = []
        self.decoders_output = []
        return output

model = CustomModel(**kwargs)
model = nn.DataParallel(model).cuda()

如果您尝试取消注释debug(i),则会看到每个数字都已调试(即打印)多次。

0 个答案:

没有答案