ValueError: Unable to create tensor

Time: 2021-05-10 19:46:45

Tags: python tensorflow sentiment-analysis bert-language-model

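I am trying to do sentiment analysis with BertModel, but when I try to train my model I get this error: "ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length." Can anyone help me fix it? Thanks. Here is my code: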

!pip --quiet install transformers
!pip --quiet install datasets
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments
from transformers import Trainer
from datasets import load_dataset

# Set model, dataset and hyperparameters
MODEL_NAME = 'NLP/bert-base-cased-v1'
DATASET = ('xed_en_fi', 'fi_annotated')
# Hyperparameters
LEARNING_RATE=1e-5
BATCH_SIZE=32
TRAIN_EPOCHS=1

dataset = load_dataset(*DATASET)

num_labels= len(set([val for sublist in dataset['train']['labels'] for val in sublist]))
dataset['train'] = dataset['train'].filter(lambda example, idx: idx % 10 == 0, with_indices=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def encode_dataset(d):
  return tokenizer.encode_plus(
  d['sentence'],
  max_length=512,
  add_special_tokens=True, # Add '[CLS]' and '[SEP]'
  return_token_type_ids=False,
  truncation=True,
  padding='max_length',
  return_attention_mask=True,
  return_tensors='pt',  # Return PyTorch tensors
)
encoded_dataset = dataset.map(encode_dataset)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=num_labels)

train_args = TrainingArguments(
    'output_dir',    # output directory for checkpoints and predictions
    save_strategy='no',
    evaluation_strategy='epoch',
    logging_strategy='epoch',
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=TRAIN_EPOCHS)

def compute_accuracy(pred):
    y_pred = pred.predictions.argmax(axis=1)
    y_true = pred.label_ids
    return { 'accuracy': sum(y_pred == y_true) / len(y_true) }

trainer = Trainer(
      model,
      train_args,
      train_dataset=encoded_dataset['train'],
      tokenizer=tokenizer,
      compute_metrics=compute_accuracy
)

trainer.train()
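One possible cause, not confirmed anywhere in this post: the code above iterates over each entry of dataset['train']['labels'] as a list, so every example appears to carry a list of labels. If those lists differ in length, batching them into a single tensor would fail even though the input text is already padded to max_length. The sketch below uses made-up rows (the name ragged and the helper to_multi_hot are hypothetical) to show how such lists could be turned into fixed-length multi-hot vectors.

```python
from typing import List


def to_multi_hot(label_ids: List[int], num_labels: int) -> List[float]:
    """Turn a variable-length list of label ids into a fixed-length 0/1 vector."""
    vec = [0.0] * num_labels
    for i in label_ids:
        vec[i] = 1.0
    return vec


# Hypothetical rows shaped like the 'labels' column in the question:
# each example has a list of label ids, and the lists differ in length.
ragged = [[0, 2], [1], [0, 1, 3]]

# Same counting approach as in the question; this assumes the ids are
# contiguous and start at 0.
num_labels = len({v for row in ragged for v in row})

# Every row now has exactly num_labels entries, so the rows can be batched.
fixed = [to_multi_hot(row, num_labels) for row in ragged]
print(fixed)
```

If this assumption holds for the real dataset, a conversion like this could be applied to each example inside the function passed to dataset.map. Note that with multi-hot targets the argmax-based compute_accuracy above would no longer compare like with like and would need to be adapted.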

0 answers:

No answers