预期输入的batch_size(1)要与目标batch_size(11)相匹配

时间:2020-05-09 22:06:53

标签: python tensorflow pytorch classification tensor

我知道这似乎是一个常见问题,但是我找不到解决方案。我正在运行一个多标签分类模型,并且张量大小有问题。

我的完整代码如下:

from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import torch

# Instantiating tokenizer and model
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-cased')

# Instantiating quantized model
quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

# Forming data tensors
input_ids = torch.tensor(tokenizer.encode(x_train[0], add_special_tokens=True)).unsqueeze(0)
labels = torch.tensor(Y[0]).unsqueeze(0)

# Train model
outputs = quantized_model(input_ids, labels=labels)
loss, logits = outputs[:2]

哪个会产生错误:

ValueError: Expected input batch_size (1) to match target batch_size (11)

Input_ids如下:

tensor([[  101,   789,   160,  1766,  1616,  1110,   170,  1205,  7727,  1113,
           170,  2463,  1128,  1336,  1309,  1138,   112,   119, 11882, 11545,
           119,   108, 15710,   108,  3645,   108,  3994,   102]])

形状:

torch.Size([1, 28])

和标签如下:

tensor([[0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1]])

形状:

torch.Size([1, 11])

input_ids的大小将随着要编码的字符串大小的变化而变化。

我还注意到,当输入5个Y值以产生5个标签时,会产生错误:

ValueError: Expected input batch_size (1) to match target batch_size (55).

带有标签形状:

torch.Size([1, 5, 11])

(请注意,我没有输入5个input_id,这大概就是为什么输入大小保持不变的原因)

我已经尝试了几种不同的方法来使它们起作用,但是我现在很茫然。我真的很感谢一些指导。谢谢!

1 个答案:

答案 0 :(得分:2)

DistilBertForSequenceClassification的标签必须具有文档中提到的大小torch.Size([batch_size])

  • 标签(形状为torch.LongTensor的{​​{1}},可选,默认为(batch_size,))–用于计算序列分类的标签/回归损失。索引应位于None中。如果[0, ..., config.num_labels - 1]计算出回归损失(均方根损失),如果config.num_labels == 1计算出分类损失(十字熵)。

在您的情况下,您的config.num_labels > 1的大小应为labels

这对于您的数据是不可能的,这是因为序列分类为每个序列都有一个标签,但是您希望将其设为多标签分类。

据我所知,HuggingFace的转换器库中没有可立即使用的多标签模型。您将需要创建自己的模型,这并不是特别困难,因为这些额外的模型都使用相同的基本模型,并根据要解决的任务在最后添加一个适当的分类器。 HuggingFace - Multi-label Text Classification using BERT – The Mighty Transformer解释了如何做到这一点。