如何从Tensorflow检查点(ckpt)文件预测基于BERT的句子中的被屏蔽单词?

时间:2019-09-11 12:22:49

标签: python tensorflow deep-learning predict bert-language-model

我有基于BERT的模型检查点,我在Tensorflow中从头开始进行了训练。如何使用这些检查点来预测给定句子中的被掩盖词?

比如说句子是 “ [CLS] abc pqr [MASK] xyz [SEP]” 我想在[MASK]位置预测单词。

我该怎么办? 我在网上搜索了很多内容,但是每个人都在使用BERT来完成特定于任务的分类任务。 不使用BERT来预测被屏蔽的单词。

请帮助我解决此预测问题。

我使用create_pretraining_data.py创建了数据,并使用了来自官方BERT repo(https://github.com/google-research/bert)的run_pretraining.py从头开始训练模型

我已经搜索了官方bert回购中的问题。但是没有找到任何解决方案。

还查看了该存储库中的代码。他们使用的是Estimator,而不是从检查点权重开始训练的。

找不到任何方法来使用基于BERT的模型的Tensorflow检查点(从头开始训练)来预测字掩码令牌(即[MASK])。

1 个答案:

答案 0 :(得分:0)

您肯定需要从TF检查点开始吗?如果可以使用pytorch-transformers库中使用的一种预训练模型,那么我编写了一个用于执行此操作的库:FitBERT

如果必须以TF检查点开头,则有一些脚本可以将TF检查点转换为pytorch-transformers可以使用的脚本,link,并且在转换之后,您应该可以使用FitBERT,或者您可以看看我们在代码中正在做什么。