I have a server with two GPUs. With a single GPU, training would take more than 10 days to finish 1000 epochs. But when I try to use DataParallel, the program fails, apparently because of the "for" loop in my network's forward pass. How can I use DataParallel in this case? Or is there another way to speed up training? Here is my network, including my attempt at wrapping each sub-module in DataParallel:
import torch
import torch.nn as nn

class WaveNet(nn.Module):
    def __init__(self, mu, n_residue, n_skip, dilation_depth, n_repeat):
        # mu: audio quantization size
        # n_residue: residue channels
        # n_skip: skip channels
        # dilation_depth & n_repeat: dilation layer setup
        super(WaveNet, self).__init__()
        self.mu = mu
        self.dilation_depth = dilation_depth
        dilations = self.dilations = [2 ** i for i in range(dilation_depth)] * n_repeat
        self.one_hot = One_Hot(mu)  # my own one-hot encoding module
        self.from_input = nn.Conv1d(in_channels=mu, out_channels=n_residue, kernel_size=1)
        self.from_input = nn.DataParallel(self.from_input)
        self.conv_sigmoid = nn.ModuleList(
            [nn.Conv1d(in_channels=n_residue, out_channels=n_residue, kernel_size=2, dilation=d)
             for d in dilations])
        self.conv_sigmoid = nn.DataParallel(self.conv_sigmoid)
        self.conv_tanh = nn.ModuleList(
            [nn.Conv1d(in_channels=n_residue, out_channels=n_residue, kernel_size=2, dilation=d)
             for d in dilations])
        self.conv_tanh = nn.DataParallel(self.conv_tanh)
        self.skip_scale = nn.ModuleList(
            [nn.Conv1d(in_channels=n_residue, out_channels=n_skip, kernel_size=1)
             for d in dilations])
        self.skip_scale = nn.DataParallel(self.skip_scale)
        self.residue_scale = nn.ModuleList(
            [nn.Conv1d(in_channels=n_residue, out_channels=n_residue, kernel_size=1)
             for d in dilations])
        self.residue_scale = nn.DataParallel(self.residue_scale)
        self.conv_post_1 = nn.Conv1d(in_channels=n_skip, out_channels=n_skip, kernel_size=1)
        self.conv_post_1 = nn.DataParallel(self.conv_post_1)
        self.conv_post_2 = nn.Conv1d(in_channels=n_skip, out_channels=mu, kernel_size=1)
        self.conv_post_2 = nn.DataParallel(self.conv_post_2)
    def forward(self, input, train=True):
        output = self.preprocess(input, train)
        skip_connections = []  # save for generation purposes
        # This zip is where the TypeError below is raised: after the
        # DataParallel wrapping above, self.conv_sigmoid and friends
        # are no longer iterable ModuleLists.
        for s, t, skip_scale, residue_scale in zip(self.conv_sigmoid, self.conv_tanh,
                                                   self.skip_scale, self.residue_scale):
            output, skip = self.residue_forward(output, s, t, skip_scale, residue_scale)
            skip_connections.append(skip)
        # sum up skip connections, cropped to the current output length
        output = sum([s[:, :, -output.size(2):] for s in skip_connections])
        output = self.postprocess(output, train)
        return output
This is the error I get:

TypeError: zip argument #1 must support iteration
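From reading the nn.DataParallel docs, I believe the usual pattern is to wrap the whole model once rather than each sub-module, so every replica keeps its plain ModuleLists and the for loop in forward still works. A minimal sketch of what I have in mind, assuming the nn.DataParallel calls inside __init__ are removed and with placeholder hyper-parameter values:

import torch
import torch.nn as nn

# Hypothetical hyper-parameters, chosen only for illustration.
model = WaveNet(mu=256, n_residue=24, n_skip=128,
                dilation_depth=10, n_repeat=2)

# Wrap the whole network once. DataParallel splits each input batch
# along dimension 0 across the two GPUs, and every replica runs the
# original forward, so iterating over the ModuleLists is fine there.
model = nn.DataParallel(model).cuda()

# `batch` stands in for whatever preprocess() expects; only the batch
# dimension is scattered across the GPUs, e.g.:
# output = model(batch)

If I understand correctly, the batch size also has to be at least 2 for the split across two GPUs to do anything. Is this the right approach here, or is there a better way?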