libtorch: converting a model with a ModuleDict to C++

Date: 2019-07-23 11:31:06

Tags: pytorch libtorch

Does anyone know how to convert a PyTorch model that has ModuleDict member variables to C++ through libtorch tracing?

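The main problem is that the model's forward function accepts only a single channel_tensor. For the whole model to work in C++, we would need to trace it with every possible channel_tensor value.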
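Tracing runs `forward` once and records only the tensor operations it sees; Python-side logic such as `channel_tensor.item()` followed by a dictionary lookup is evaluated during tracing and baked in as a constant. The following minimal sketch illustrates the effect with a hypothetical toy model `ToyMultiTask` standing in for the question's model: a module traced with channel 1 keeps running head "1" even when it is called with channel 2.

```python
import warnings

import torch
import torch.nn as nn


class ToyMultiTask(nn.Module):
    """Hypothetical stand-in for the question's model: one head per channel."""

    def __init__(self, channels):
        super().__init__()
        self.heads = nn.ModuleDict({str(c): nn.Linear(4, 2) for c in channels})

    def forward(self, x, channel_tensor):
        # Same pattern as the question: a Python-side .item() picks the head.
        channel = str(channel_tensor.item())
        return self.heads[channel](x)


torch.manual_seed(0)
model = ToyMultiTask([1, 2])
x = torch.randn(3, 4)
one = torch.tensor([1])
two = torch.tensor([2])

with warnings.catch_warnings():
    # The tracer warns that .item() converts a tensor to a Python constant.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(model, (x, one))

# Only head "1" was recorded, so passing channel 2 still runs head "1".
print(torch.allclose(traced(x, two), model(x, one)))  # True
print(torch.allclose(traced(x, two), model(x, two)))  # False
```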

I can trace the channel with ID = 1 as shown below, but how do I combine the traced models for channel_id = 2, 3, 4, 5 into one?

    import torch

    # Trace the model for channel id 1 only.
    channel = torch.ones(1, dtype=torch.int64)
    traced_script_module = torch.jit.trace(model, (premise, premise_length, hypotheses, hypotheses_length, channel))

    output = traced_script_module(premise, premise_length, hypotheses, hypotheses_length, channel)
    traced_script_module.save('deploy-trace-multitask.pt')
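A commonly used alternative to tracing is `torch.jit.script`, which compiles the control flow in `forward` instead of recording a single execution, so one saved file can select the channel at run time. This is a swapped-in technique, not the tracing approach asked about, and the sketch is hedged: `ScriptableMultiTask` is a hypothetical toy model, it assumes a PyTorch version whose TorchScript supports iterating over `ModuleDict.items()`, and it assumes the rest of the real model (RNNDropout, Seq2SeqEncoder, get_mask) could also be made scriptable, which the question does not show.

```python
import io

import torch
import torch.nn as nn


class ScriptableMultiTask(nn.Module):
    """Hypothetical toy model; the head is selected inside the compiled graph."""

    def __init__(self, channels):
        super().__init__()
        self.heads = nn.ModuleDict({str(c): nn.Linear(4, 2) for c in channels})

    def forward(self, x: torch.Tensor, channel_tensor: torch.Tensor) -> torch.Tensor:
        channel = str(int(channel_tensor))
        # Fallback for an unknown channel id: zeros with the heads' output shape.
        out = torch.zeros([x.size(0), 2], dtype=x.dtype)
        # TorchScript cannot index a ModuleDict with a runtime string, but it
        # can unroll a loop over items() and compare each constant key.
        for key, head in self.heads.items():
            if key == channel:
                out = head(x)
        return out


torch.manual_seed(0)
model = ScriptableMultiTask([1, 2, 3, 4, 5])
scripted = torch.jit.script(model)

# Round-trip through an in-memory buffer; a file path works the same way.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

x = torch.randn(3, 4)
for c in [1, 2, 3, 4, 5]:
    same = torch.allclose(loaded(x, torch.tensor([c])), model.heads[str(c)](x))
    print(c, same)
```

Saved to a file instead of a buffer, the module could be loaded from C++ with `torch::jit::load` and called with the channel id as a one-element int64 tensor. The trade-off is that scripting requires `forward` to be type-checkable TorchScript, whereas tracing accepts arbitrary Python but records only the path taken for the example inputs.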

Snippet of the model definition:

    self._word_embedding = nn.Embedding(self.vocab_size,
                                        self.embedding_dim,
                                        padding_idx=padding_idx,
                                        _weight=embeddings)

    if self.dropout:
        self._rnn_dropout = RNNDropout(p=self.dropout) #shared by all tasks
        # self._rnn_dropout = nn.Dropout(p=self.dropout)

    self._encoding = Seq2SeqEncoder(nn.LSTM,
                                    self.embedding_dim,
                                    self.hidden_size,
                                    bidirectional=True)

    #multi-task
    self._attention = nn.ModuleDict({})
    self._projection = nn.ModuleDict({})
    self._classification = nn.ModuleDict({})
    for channel in channels_list:
        self.update(channel)

    # Initialize all weights and biases in the model.
    self.apply(_init_esim_weights)

def update(self, channel):
    channel = str(channel)
    self._attention.update({channel : SoftmaxAttention()})

    self._projection.update({channel : nn.Sequential(nn.Linear(4*2*self.hidden_size, self.hidden_size), nn.ReLU())})

    self._classification.update({channel : nn.Sequential(nn.Dropout(p=self.dropout),
                                         nn.Linear(4*self.hidden_size,
                                                   self.hidden_size),
                                         nn.Tanh(),
                                         nn.Dropout(p=self.dropout),
                                         nn.Linear(self.hidden_size,
                                                   self.num_classes))})

def forward(self,
            premises,
            premises_lengths,
            hypotheses,
            hypotheses_lengths,
            channel_tensor): #must be a tensor
    """
    Args:
        premises: A batch of variable-length sequences of word indices
            representing premises. The batch is assumed to be of size
            (batch, premises_length).
        premises_lengths: A 1D tensor containing the lengths of the
            premises in 'premises'.
        hypotheses: A batch of variable-length sequences of word indices
            representing hypotheses. The batch is assumed to be of size
            (batch, hypotheses_length).
        hypotheses_lengths: A 1D tensor containing the lengths of the
            hypotheses in 'hypotheses'.
        channel_tensor: A one-element int64 tensor holding the id of the
            task channel whose per-channel modules are used.

    Returns:
        logits: A tensor of size (batch, num_classes) containing the
            logits for each output class of the model.
        probabilities: A tensor of size (batch, num_classes) containing
            the probabilities of each output class in the model.
    """
    channel_id = channel_tensor.item()
    channel = str(channel_id)
    premises_mask = get_mask(premises, premises_lengths).to(self.device)
    hypotheses_mask = get_mask(hypotheses, hypotheses_lengths)\
        .to(self.device)

    embedded_premises = self._word_embedding(premises)
    embedded_hypotheses = self._word_embedding(hypotheses)

    if self.dropout:
        embedded_premises = self._rnn_dropout(embedded_premises)
        embedded_hypotheses = self._rnn_dropout(embedded_hypotheses)

    encoded_premises = self._encoding(embedded_premises,
                                      premises_lengths)
    encoded_hypotheses = self._encoding(embedded_hypotheses,
                                        hypotheses_lengths)

    attended_premises, attended_hypotheses =\
        self._attention[channel](encoded_premises, premises_mask,
                        encoded_hypotheses, hypotheses_mask)
    # ... rest of the code is omitted

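The `update` method converts each channel id with `str(channel)` before registering the per-channel modules. A minimal sketch of why, using a hypothetical toy `ModuleDict`: its keys become submodule names, and `nn.Module.add_module` raises a `TypeError` for names that are not strings.

```python
import torch.nn as nn

heads = nn.ModuleDict()

# ModuleDict keys become submodule names, which must be strings.
try:
    heads.update({1: nn.Linear(8, 2)})
    int_keys_accepted = True
except TypeError:
    int_keys_accepted = False

# Converting the channel id first, as update() in the question does, works.
for channel in [1, 2, 3]:
    heads.update({str(channel): nn.Linear(8, 2)})

print(int_keys_accepted)   # False
print(list(heads.keys()))  # ['1', '2', '3']
```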
0 answers