如何使用拥抱型脸部变压器训练编码器-解码器模型以完成翻译任务?

时间:2020-06-18 09:31:06

标签: huggingface-transformers machine-translation encoder-decoder

我想训练一个编码器解码器模型,如下所述,用于翻译任务。有人可以指导我如何为这种模型建立培训管道吗?任何链接或代码段都将不胜感激。

class SafeAreaPersistentHeaderDelegate extends SliverPersistentHeaderDelegate {
  final Widget title;

  final Widget flexibleSpace;

  final double expandedHeight;

  SafeAreaPersistentHeaderDelegate(
      {this.title, this.flexibleSpace, this.expandedHeight});

  @override
  Widget build(
      BuildContext context, double shrinkOffset, bool overlapsContent) {
    final Widget appBar = FlexibleSpaceBar.createSettings(
      minExtent: minExtent,
      maxExtent: maxExtent,
      currentExtent: max(minExtent, maxExtent - shrinkOffset),
      toolbarOpacity: 1,
      child: AppBar(
        backgroundColor: Colors.blue,
          automaticallyImplyLeading: false,
          title: title,
          flexibleSpace: (title == null && flexibleSpace != null)
              ? Semantics(child: flexibleSpace, header: true)
              : flexibleSpace,
          toolbarOpacity: 1,
          bottomOpacity: 1.0),
    );
    return appBar;
  }

  @override
  double get maxExtent => expandedHeight;

  @override
  double get minExtent => kToolbarHeight;

  @override
  bool shouldRebuild(SafeAreaPersistentHeaderDelegate old) {
    if (old.flexibleSpace != flexibleSpace) {
      return true;
    }
    return false;
  }
}

1 个答案:

答案 0 :(得分:2)

编码器-解码器模型的使用方式与 Transformers 中的任何其他模型相同。它接受成批标记化文本作为词汇索引(即,您需要一个适合您的序列到序列任务的标记器)。当您使用输入 (input_ids) 和所需输出(decoder_input_idslabels)为模型提供数据时,您将获得可以在训练期间优化的损失值。请注意,如果批次中的句子长度不同,您也需要进行屏蔽。这是 EncoderDecoderModel 文档的最小示例:

from transformers import EncoderDecoderModel, BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    'bert-base-uncased', 'bert-base-uncased')
input_ids = torch.tensor(
    tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
outputs = model(
    input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids, 
    return_dict=True)
loss = outputs.loss

如果您不想自己编写训练循环,可以使用 Transformers 的数据集处理 (DataCollatorForSeq2Seq) 和训练 (Seq2SeqTrainer) 实用程序。您可以关注Seq2Seq example on GitHub