How can I use DataParallel when there is a "for" loop in the network?

Date: 2018-03-25 06:27:43

Tags: deep-learning pytorch

I have a server with two GPUs. On a single GPU it would take me more than 10 days to finish 1000 epochs, but when I try to use DataParallel the program fails, apparently because of the "for" loop in my network. How can I use DataParallel in this situation? Or is there another way to speed up training?
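For context, the failure itself is easy to reproduce: nn.DataParallel is a single nn.Module and, unlike nn.ModuleList, it does not implement iteration, so any for/zip over a wrapped list fails. A minimal sketch with a made-up three-layer list, wrapped the same way as in the full network below:

    import torch.nn as nn

    layers = nn.ModuleList([nn.Conv1d(8, 8, kernel_size=2) for _ in range(3)])
    wrapped = nn.DataParallel(layers)

    list(zip(layers, layers))    # fine: ModuleList implements __iter__
    list(zip(wrapped, wrapped))  # TypeError: zip argument #1 must support iteration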

import torch.nn as nn

class WaveNet(nn.Module):
    def __init__(self, mu, n_residue, n_skip, dilation_depth, n_repeat):
        # mu: audio quantization size
        # n_residue: residue channels
        # n_skip: skip channels
        # dilation_depth & n_repeat: dilation layer setup
        super(WaveNet, self).__init__()
        self.mu = mu
        self.dilation_depth = dilation_depth
        dilations = self.dilations = [2 ** i for i in range(dilation_depth)] * n_repeat
        self.one_hot = One_Hot(mu)  # One_Hot: custom module defined elsewhere
        self.from_input = nn.Conv1d(in_channels=mu, out_channels=n_residue, kernel_size=1)
        self.from_input = nn.DataParallel(self.from_input)
        self.conv_sigmoid = nn.ModuleList(
            [nn.Conv1d(in_channels=n_residue, out_channels=n_residue, kernel_size=2, dilation=d)
             for d in dilations])
        self.conv_sigmoid = nn.DataParallel(self.conv_sigmoid)
        self.conv_tanh = nn.ModuleList(
            [nn.Conv1d(in_channels=n_residue, out_channels=n_residue, kernel_size=2, dilation=d)
             for d in dilations])
        self.conv_tanh = nn.DataParallel(self.conv_tanh)
        self.skip_scale = nn.ModuleList(
            [nn.Conv1d(in_channels=n_residue, out_channels=n_skip, kernel_size=1)
             for d in dilations])
        self.skip_scale = nn.DataParallel(self.skip_scale)
        self.residue_scale = nn.ModuleList(
            [nn.Conv1d(in_channels=n_residue, out_channels=n_residue, kernel_size=1)
             for d in dilations])
        self.residue_scale = nn.DataParallel(self.residue_scale)
        self.conv_post_1 = nn.Conv1d(in_channels=n_skip, out_channels=n_skip, kernel_size=1)
        self.conv_post_1 = nn.DataParallel(self.conv_post_1)
        self.conv_post_2 = nn.Conv1d(in_channels=n_skip, out_channels=mu, kernel_size=1)
        self.conv_post_2 = nn.DataParallel(self.conv_post_2)

    def forward(self, input, train=True):
        # preprocess(), residue_forward() and postprocess() are omitted here
        output = self.preprocess(input, train)
        skip_connections = []  # save for generation purposes
        # the zip() below raises the TypeError: the ModuleLists were replaced by
        # nn.DataParallel wrappers in __init__, and DataParallel is not iterable
        for s, t, skip_scale, residue_scale in zip(self.conv_sigmoid, self.conv_tanh,
                                                   self.skip_scale, self.residue_scale):
            output, skip = self.residue_forward(output, s, t, skip_scale, residue_scale)
            skip_connections.append(skip)
        # sum up skip connections
        output = sum([s[:, :, -output.size(2):] for s in skip_connections])
        output = self.postprocess(output, train)
        return output


TypeError: zip argument #1 must support iteration
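The usual remedy is to apply nn.DataParallel once, to the whole network, rather than to each submodule: every replica then executes the unmodified forward(), for-loop included, on its share of the batch. A sketch of that setup, assuming the internal DataParallel wrappers are removed from __init__ and using hypothetical hyperparameter values:

    import torch
    import torch.nn as nn

    # hypothetical hyperparameter values, chosen only for illustration
    net = WaveNet(mu=256, n_residue=24, n_skip=128, dilation_depth=10, n_repeat=2)

    if torch.cuda.device_count() > 1:
        # a single wrapper around the whole model: DataParallel splits each
        # input batch across the GPUs and gathers the per-replica outputs
        net = nn.DataParallel(net)
    net = net.cuda()

    # output = net(batch)  # batch shape/type depends on preprocess(), not shown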

0 Answers:

No answers yet.