ValueError:预期输入batch_size(1)匹配目标batch_size(26)

时间:2020-06-06 06:58:27

标签: pytorch lstm transformer

将BERT与BiLSTM一起使用时出现以下错误(我在BERT上的batch_size为26)。我想连接BERT的最后4个隐藏层,然后将其提供给BiLSTM。这是我的模型:

from transformers import BertPreTrainedModel, BertModel
import torch.nn as nn
import torch
import torch.nn.functional as F

class BERT(BertPreTrainedModel):
    def __init__(self, config):
        super(BERT, self).__init__(config)
        self.device = config.device
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(0.1)

        self.lstm = nn.LSTM(input_size=config.hidden_size * 4, hidden_size=500, num_layers=3, dropout=0.5, bidirectional=True)
        self.qa_outputs = nn.Linear(500*2, config.num_labels)

        self.weight_class = config.weight_class
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        with torch.no_grad():
            outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
            cls_output = torch.cat((outputs[2][-4][:,0, ...],outputs[2][-3][:,0, ...], outputs[2][-2][:,0, ...], outputs[2][-1][:,0, ...]),-1)
            cls_output = self.lstm(cls_output.unsqueeze(0))[0]
            logits = self.qa_outputs(cls_output)
            return logits

    def loss(self, input_ids, attention_mask, token_type_ids, label):
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        cls_output = torch.cat((outputs[2][-4][:,0, ...],outputs[2][-3][:,0, ...], outputs[2][-2][:,0, ...], outputs[2][-1][:,0, ...]),-1)
        cls_output = self.lstm(cls_output.unsqueeze(0))[0]
        logits = self.qa_outputs(cls_output)

        target = label
        class_weights = torch.FloatTensor(self.weight_class).to(self.device)
        loss = F.cross_entropy(logits, target, weight=class_weights)

        predict_value = torch.max(logits, 1)[1]
        list_predict = predict_value.cpu().numpy().tolist()
        list_target = target.cpu().numpy().tolist()

        return loss, list_predict, list_target

真的不知道如何调试它。此错误的任何解决方案。预先感谢。

0 个答案:

没有答案