在pytorch数据并行模式下,如何使用全局张量?

时间:2019-04-22 10:20:43

标签: mpi pytorch

在此示例中,我希望z_proto对于不同的GPU可以是全局的。但是,在数据并行模式下,它也分为不同的GPU。如何解决这样的问题?谢谢。

class SequencePrototypeTokenClassification(nn.Module):
    def __init__(self,seq_model, label_num):
        super(SequencePrototypeTokenClassification, self).__init__()
        self.seq_model = seq_model
        self.label_num = label_num

    def forward(self, input_ids, token_type_ids, attention_mask, labels, z_proto, n_query, target_inds):
        z, _ = self.seq_model(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        z_dim = z.size(-1)
        zq = z.squeeze().view(-1, z_dim)
        dists = euclidean_dist(zq, z_proto)
        log_p_y = F.log_softmax(-dists, dim=1).view(-1, self.label_num)
        loss_val = -log_p_y.gather(1, self.target_inds).squeeze().view(-1).mean()
        _, y_hat = log_p_y.max(1)

        return loss_val, y_hat

2 个答案:

答案 0 :(得分:1)

根据您的上述代码,z_proto似乎是转发函数的参数之一,而不是模型的一部分。因此,只需将其存储在主GPU上的tensor中,即可使其在GPU之间具有相同的值。

编辑

基于the documentation,看来DataParallel将所有输入分配给各个GPU的正向传递函数。您可以通过将其作为类变量存储在模型对象本身中的方法来规避它。如果它不是静态变量,则可以在调用forward函数之前更新该值。

class SequencePrototypeTokenClassification(nn.Module):
    def __init__(self,seq_model, label_num):
        ...
        self.z_proto = None
        ...
        ...


#Training loop
    ...
    model.z_proto = value
    model.forward()
    ...


答案 1 :(得分:0)

事实证明,DataParallel仅复制nn.Parameter中的nn.Module。因此,我在模块中随机初始化了一个名为nn.Parameter的{​​{1}}并将张量z_proto的值复制到该参数中。然后将参数复制到4个GPU中。