使用huggingface的distilbert模型生成文本

时间:2019-12-08 22:56:52

标签: machine-learning nlp pytorch transform distilbert

一段时间以来一直在努力拥抱Face的DistilBert模型,因为文档似乎非常不清楚,并且它们的示例(例如https://github.com/huggingface/transformers/blob/master/notebooks/Comparing-TF-and-PT-models-MLM-NSP.ipynbhttps://github.com/huggingface/transformers/tree/master/examples/distillation)非常厚,而他们展示的东西却没有似乎没有充分的记录。

我想知道这里是否有人有任何经验,并且知道一些关于模型基本使用Python的良好代码示例。即:

  • 如何正确地将模型的输出解码为实际文本(无论我如何改变其形状,令牌化程序似乎都愿意对其进行解码,并始终产生[UNK]令牌的序列)

  • 如何实际使用其调度程序和优化程序来训练简单文本到文本任务的模型。

1 个答案:

答案 0 :(得分:1)

要解码输出,你可以这样做

        prediction_as_text = tokenizer.decode(output_ids, skip_special_tokens=True)

output_ids 包含生成的令牌 ID。它也可以是一个批处理(每行输出 id),那么 prediction_as_text 也将是一个包含每行文本的二维数组。 skip_special_tokens=True 过滤掉训练中使用的特殊标记,例如(句子结尾)、(句子开头)等。这些特殊标记当然因模型而异,但几乎每个模型都有在训练过程中使用的特殊标记训练和推理。

没有一种简单的方法可以摆脱未知令牌[UNK]。这些模型的词汇量有限。如果模型遇到不在其词汇表中的子词,则将其替换为特殊的未知标记,并使用这些标记训练模型。所以,它也学会了生成[UNK]。有多种方法可以处理它,例如用第二高的可能标记替换它,或者使用波束搜索并采用不包含任何未知标记的最可能的句子。但是,如果您真的想摆脱这些,您应该使用使用字节对编码的模型。彻底解决生词问题。正如您在此链接中所读到的,Bert 和 DistilBert 使用子工作标记化并有这样的限制。 https://huggingface.co/transformers/tokenizer_summary.html

要使用调度程序和优化程序,您应该使用 TrainerTrainingArguments 类。下面我发布了一个来自我自己的项目的示例。

    output_dir=model_directory,
    num_train_epochs=args.epochs,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    warmup_steps=500,
    weight_decay=args.weight_decay,
    logging_dir=model_directory,
    logging_steps=100,
    do_eval=True,
    evaluation_strategy='epoch',
    learning_rate=args.learning_rate,
    load_best_model_at_end=True, # the last checkpoint is the best model wrt metric_for_best_model
    metric_for_best_model='eval_loss',
    lr_scheduler_type = 'linear'
    greater_is_better=False, 
    save_total_limit=args.epochs if args.save_total_limit == -1 else args.save_total_limit,

)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    optimizers=[torch.optim.Adam(params=model.parameters(), 
    lr=args.learning_rate), None], // optimizers
    tokenizer=tokenizer,
)

对于其他调度程序类型,请参阅此链接:https://huggingface.co/transformers/main_classes/optimizer_schedules.html