Expected input batch_size (32) to match target batch_size (19840): BERT classifier

Time: 2020-07-09 21:29:51

Tags: python deep-learning nlp pytorch

I am getting this error in my code:

The shapes look fine, but I still get the error:

    model = BertForSequenceClassification.from_pretrained("pretrained/", num_labels=ohe_count)
    model.to(device)

    from IPython.display import clear_output

    train_loss_set = []
    train_loss = 0

    model.train()

    for step, batch in enumerate(train_dataloader):
        # Move the batch to the GPU for computation
        batch = tuple(t.to(device) for t in batch)

        # Unpack the data from the dataloader
        b_input_ids, b_input_mask, b_labels = batch
        b_input_ids = b_input_ids.type(torch.LongTensor)
        b_input_mask = b_input_mask.type(torch.LongTensor)
        b_labels = b_labels.type(torch.LongTensor)

        b_input_ids = b_input_ids.to(device)
        b_input_mask = b_input_mask.to(device)
        b_labels = b_labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        loss = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
        train_loss_set.append(loss[0].item())

        # Backward pass
        loss[0].backward()
        optimizer.step()
        train_loss += loss[0].item()

        # Redraw the running loss curve
        clear_output(True)
        plt.plot(train_loss_set)
        plt.title("Training loss")
        plt.xlabel("Batch")
        plt.ylabel("Loss")
        plt.show()

The tensor shapes are:

    b_input_ids.shape = torch.Size([32, 100])
    b_labels.shape = torch.Size([32, 620])

1 Answer:

Answer 0 (score: 0)

Never mind. I was trying to do multi-label classification, but BertForSequenceClassification cannot do that.
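
For anyone landing here with the same message, the numbers in the error line up with the shapes in the question: 32 × 620 = 19840. That is what you get when a single-label cross-entropy loss flattens a `[32, 620]` multi-hot label tensor into one class index per row. The sketch below is only an illustration of that arithmetic and of one common workaround, not the asker's eventual fix. The random `logits` and `labels` tensors are stand-ins for the model output and the real targets, and the use of `BCEWithLogitsLoss` is an assumption about how one might train a multi-label head.

```python
import torch
import torch.nn as nn

batch_size, num_labels = 32, 620  # shapes reported in the question

# Stand-ins (random, for illustration only): classifier logits and multi-hot targets
logits = torch.randn(batch_size, num_labels)
labels = torch.randint(0, 2, (batch_size, num_labels))

# Single-label view: one row of logits per example, one class index per example.
# Flattening the multi-hot targets gives 32 * 620 entries instead of 32.
flat_logits = logits.view(-1, num_labels)
flat_labels = labels.view(-1)
print(flat_logits.shape[0], flat_labels.shape[0])

# Multi-label alternative (assumption): treat each label as an independent
# binary decision and apply binary cross-entropy to the raw logits.
loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(logits, labels.float())  # targets must be float, same shape as logits
print(loss.dim())
```

In the training loop from the question, this would presumably mean calling the model without `labels=`, taking the logits from the first element of the output, and computing the loss outside the model. Newer releases of `transformers` also accept `problem_type="multi_label_classification"` in the model config, which, as far as I know, selects a binary cross-entropy loss inside `BertForSequenceClassification`; check the documentation for the version you have installed.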