I am running my NN model using DataParallel on 3 GPUs. On one GPU, memory usage climbs to ~12 GB and the program stops with an out-of-memory error, while the other two GPUs stay at a fairly low ~2-3 GB.
Is there any way to keep GPU memory usage balanced when using DataParallel in PyTorch?
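For context, as I understand it: nn.DataParallel scatters the input batch across the GPUs, but gathers every replica's output back onto a single output_device (GPU 0 by default), so large outputs, and anything computed from them afterwards, live on that one card. A minimal sketch of the default behavior, assuming three visible GPUs (the toy model and sizes are illustrations, not from my actual code):

import torch
import torch.nn as nn

# Toy model; DataParallel replicates it onto every device in device_ids.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU()).cuda()
parallel = nn.DataParallel(model, device_ids=[0, 1, 2], output_device=0)

x = torch.randn(96, 512).cuda()  # scattered into 3 chunks of 32 rows each
y = parallel(x)                  # replica outputs are gathered on cuda:0
print(y.device)                  # cuda:0, so the full output sits on one GPU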
EDIT: I am working on neural machine translation (NMT), and I am sharing the part of my code where I use DataParallel.
import torch
import torch.nn as nn
import torch.nn.functional as f

# EmbeddingLayer and Encoder_Decoder are defined elsewhere in my project.


class NMT(nn.Module):
    """A sequence-to-sequence model for machine translation."""

    def __init__(self, dictionary, embedding_index, args):
        super(NMT, self).__init__()
        self.config = args
        self.src_embedding = EmbeddingLayer(len(dictionary[0]), False, self.config)
        self.tgt_embedding = EmbeddingLayer(len(dictionary[1]), True, self.config)
        if embedding_index is not None:
            if isinstance(embedding_index, tuple):
                self.src_embedding.init_embedding_weights(dictionary[0], embedding_index[0], self.config.emsize)
                self.tgt_embedding.init_embedding_weights(dictionary[1], embedding_index[1], self.config.emsize)
            else:
                self.src_embedding.init_embedding_weights(dictionary[0], embedding_index, self.config.emsize)

        self.encoder_decoder = Encoder_Decoder(args)
        if torch.cuda.device_count() > 1:
            self.encoder_decoder = torch.nn.DataParallel(self.encoder_decoder)

        # word decoding layer
        self.out = nn.Linear(self.config.emsize, len(dictionary[1]))
        # tie target embedding weights with decoder prediction layer weights
        self.tgt_embedding.embedding.weight = self.out.weight

    def forward(self, s1, s1_len, s2, s2_len):
        """
        Forward pass of the sequence-to-sequence machine translation model.
        :param s1: source sentences [batch_size x max_s1_length]
        :param s1_len: source sentences' lengths [batch_size]
        :param s2: target sentences [batch_size x max_s2_length]
        :param s2_len: target sentences' lengths [batch_size]
        :return: average decoding loss (scalar)
        """
        # embedded_s1: batch_size x max_s1_length x em_size
        embedded_s1 = self.src_embedding(s1)
        # embedded_s2: batch_size x max_s2_length x em_size
        embedded_s2 = self.tgt_embedding(s2)

        # decoder_out: batch_size x max_s2_length x em_size
        decoder_out = self.encoder_decoder(embedded_s1, s1_len, embedded_s2)
        predictions = f.log_softmax(self.out(decoder_out.view(-1, decoder_out.size(2))), dim=1)
        predictions = predictions.view(*decoder_out.size()[:-1], -1)

        decoding_loss, total_local_decoding_loss_element = 0, 0
        for idx in range(s2.size(1) - 1):
            local_loss, num_local_loss = self.compute_decoding_loss(predictions[:, idx, :], s2[:, idx + 1], idx, s2_len)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss
        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element
        return decoding_loss
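For completeness: compute_decoding_loss is defined elsewhere in my code. From its usage above, it returns a per-timestep loss plus the number of real (non-padding) tokens that contributed to it. A simplified, hypothetical sketch of what it does (not my exact implementation):

def compute_decoding_loss(self, log_probs, targets, idx, s2_len):
    # log_probs: [batch_size x vocab], targets: [batch_size]
    # Hypothetical masking rule: only count positions that exist in the
    # target sentence, i.e. where the sentence length exceeds idx + 1.
    mask = (s2_len > idx + 1).float()
    # Negative log-likelihood of the gold next word for each batch element.
    token_loss = -log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    # Return (summed masked loss, number of real tokens at this step).
    return (token_loss * mask).sum(), mask.sum().item()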
This is where I specifically apply DataParallel:
if torch.cuda.device_count() > 1:
    self.encoder_decoder = torch.nn.DataParallel(self.encoder_decoder)
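Because only encoder_decoder is wrapped, the gathered decoder_out for the whole batch, the big vocabulary projection, and the log-softmax all run on GPU 0, which matches the skew I see. One commonly suggested pattern (a sketch of mine, not code I am running; CriterionWrapper is a made-up name) is to move the projection and loss inside the replicated module, so that only a scalar per GPU is gathered:

import torch
import torch.nn as nn

class CriterionWrapper(nn.Module):
    """Hypothetical sketch: run the projection and the loss inside the
    replicated module, so DataParallel gathers one scalar per GPU instead
    of a full [batch x seq_len x vocab] prediction tensor on GPU 0."""

    def __init__(self, encoder_decoder, out_layer, pad_idx):
        super(CriterionWrapper, self).__init__()
        self.encoder_decoder = encoder_decoder
        self.out = out_layer
        # Padding tokens are excluded from the loss via ignore_index,
        # a simplification of my per-step length masking.
        self.criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

    def forward(self, embedded_s1, s1_len, embedded_s2, targets):
        decoder_out = self.encoder_decoder(embedded_s1, s1_len, embedded_s2)
        logits = self.out(decoder_out.view(-1, decoder_out.size(2)))
        # A scalar per replica is cheap to gather on the output device.
        return self.criterion(logits, targets.contiguous().view(-1))

# Usage sketch (shifted_targets would be s2 shifted by one position):
# wrapped = torch.nn.DataParallel(CriterionWrapper(encoder_decoder, out, pad_idx))
# loss = wrapped(embedded_s1, s1_len, embedded_s2, shifted_targets).mean()

Note that this conflicts with my current setup, where the tied prediction layer is deliberately kept outside DataParallel.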
The figure below gives a high-level overview of the model.
I deliberately keep the target_embedding layer and the next_word_prediction layer outside DataParallel, because those layers carry a large number of parameters (~24M) and their weights are tied.
Also, I don't train the source_embedding layer; its weights are frozen. With this setup I observed an improvement in running time on a small dataset, but when I try it with a large dataset, it gives me an out-of-memory error.
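For reference, the freezing itself is just the standard requires_grad trick (a minimal sketch with toy sizes, not my actual EmbeddingLayer):

import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=50000, embedding_dim=300)  # toy sizes
embedding.weight.requires_grad = False  # frozen: no gradients, no updates

# Only hand trainable parameters to the optimizer:
# optimizer = torch.optim.SGD(
#     (p for p in model.parameters() if p.requires_grad), lr=0.1)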