如何将微调的拥抱面变压器模型分为两个网络?

时间:2020-06-30 13:15:40

标签: python pytorch huggingface-transformers

我正在一个具有多个文本分类任务的框架上工作(如二进制文本分类示例中所示)。对于这些任务,我使用功能强大的huggingface-transformer library。对于二进制分类任务,我的代码实际上看起来像这样。

class BertForBinaryDocumentClassification(BertPreTrainedModel):

def __init__(self, config):
    super().__init__(config)
    self.bert = BertModel.from_pretrained('bert-base-uncased')
    self.dropout = nn.Dropout(0.1)
    self.classifier = nn.Linear(in_features=config.hidden_size, out_features=config.num_labels)
    self.init_weights()

def forward(self, input_ids, token_type_ids=None, attention_mask=None,head_mask=None, labels=None):
    outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, head_mask=head_mask)
    pooled_output = outputs[0]
    pooled_output = self.dropout(pooled_output)
    logits = self.classifier(pooled_output)

    outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

    loss_fct = BCELoss()
    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
    outputs = (loss,) + outputs

    return outputs  # (loss), logits, (hidden_states), (attentions)

现在,我遇到了性能问题,因为对于每个分类任务,我都使用带有适当分类头(二进制和多分类)的整个BERT模型。一种型号约为400-500Mb。

现在,我想知道是否可以在BERT模型中拆分这些网络,该模型生成句子嵌入并将这些编码的句子嵌入发送到具有最终线性层的独立分类网络,如下面的代码所示。

现在的问题是,我不确定如何解决此问题。我正在阅读冻结最后一层并将中间嵌入发送到分类网络。一件事是我想知道它是否仍代表从1到11的BERT层的全部信息?

另一件事是,我不太确定独立分类层的外观。我认为它应该看起来像这样。

class BinaryClassifier(nn.Module):

def __init__(self, vocab_size, embedding_dim, context_size):
    super(BinaryClassifier, self).__init__()
    self.classifier = nn.Linear(in_features=config.hidden_size, out_features=config.num_labels)
    self.init_weights()

def forward(self, inputs):
    embeds = self.embeddings(inputs).view((1, -1))
    logits = self.classifier(embeds)
    outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

    loss_fct = BCELoss()
    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
    outputs = (loss,) + outputs

    return outputs  # (loss), logits, (hidden_states), (attentions)

有人可以告诉我我的方法是否正确,以及如何解决此任务吗?预先感谢。

0 个答案:

没有答案