How do I feed the output of a fine-tuned BERT model as input to another fine-tuned BERT model?

Asked: 2020-02-19 10:13:53

Tags: pytorch pre-trained-model bert-language-model huggingface-transformers

I have fine-tuned two separate BERT models (bert-base-uncased), one on a sentiment-analysis task and one on a POS-tagging task. Now I want to feed the output of the POS tagger (batch, seq_length, hidden_size) as input to the sentiment model. The original bert-base-uncased model is in the 'bertModel/' folder, which contains 'model.bin' and 'config.json'. Here is my code:

import torch
import torch.nn as nn
from transformers import BertModel


class DeepSequentialModel(nn.Module):
    def __init__(self, sentiment_model_file, postag_model_file, device):
        super(DeepSequentialModel, self).__init__()

        # Load the two fine-tuned models from their saved state dicts.
        self.sentiment_model = SentimentModel().to(device)
        self.sentiment_model.load_state_dict(torch.load(sentiment_model_file, map_location=device))
        self.postag_model = PosTagModel().to(device)
        self.postag_model.load_state_dict(torch.load(postag_model_file, map_location=device))

        self.classificationLayer = nn.Linear(768, 1)

    def forward(self, seq, attn_masks):
        # (batch, seq_length, hidden_size) contextual representations from the POS tagger
        postag_context = self.postag_model(seq, attn_masks)
        sent_context = self.sentiment_model(postag_context, attn_masks)
        logits = self.classificationLayer(sent_context)
        return logits


class PosTagModel(nn.Module):
    def __init__(self):
        super(PosTagModel, self).__init__()
        self.bert_layer = BertModel.from_pretrained('bertModel/')
        self.classificationLayer = nn.Linear(768, 43)

    def forward(self, seq, attn_masks):
        cont_reps, _ = self.bert_layer(seq, attention_mask=attn_masks)
        return cont_reps


class SentimentModel(nn.Module):
    def __init__(self):
        super(SentimentModel, self).__init__()
        self.bert_layer = BertModel.from_pretrained('bertModel/')
        self.cls_layer = nn.Linear(768, 1)

    def forward(self, input, attn_masks):
        # This call raises the error shown below.
        cont_reps, _ = self.bert_layer(encoder_hidden_states=input, encoder_attention_mask=attn_masks)
        cls_rep = cont_reps[:, 0]
        return cls_rep

But I get the following error. I would appreciate it if someone could help me. Thank you!

    cont_reps, _ = self.bert_layer(encoder_hidden_states=input, encoder_attention_mask=attn_masks)
    result = self.forward(*input, **kwargs)
    TypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'

1 Answer:

Answer 0 (score: 1)

To also phrase this as an answer and make it properly visible for future visitors: passing encoder_hidden_states to forward() was first possible in transformers version 2.2.0; transformers 2.1.1, or any earlier release, does not support these arguments. Note that the link in my comment actually points to a different forward function, but apart from that the point still holds.
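
As a quick sanity check before constructing the models, a minimal sketch like the following can verify that the installed transformers release meets the 2.2.0 requirement stated above (the packaging helper used for version comparison is a separate, commonly installed dependency):

    import transformers
    from packaging import version

    # Fail fast if the installed transformers release predates support for
    # encoder_hidden_states / encoder_attention_mask in BertModel.forward().
    if version.parse(transformers.__version__) < version.parse("2.2.0"):
        raise RuntimeError(
            "transformers %s is too old for encoder_hidden_states; "
            "upgrade with: pip install --upgrade 'transformers>=2.2.0'"
            % transformers.__version__
        )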