如何使用训练有素的BERT模型检查点进行预测?

时间:2019-06-28 04:28:46

标签: python tensorflow neural-network google-cloud-tpu bert-language-model

我用SQUAD 2.0训练了BERT,并得到了model.ckpt.data,model.ckpt.meta。使用BERT-master / run_squad.py

将输出目录中的model.ckpt.index(F1分数:81)以及projections.json等
python run_squad.py \
  --vocab_file=$BERT_LARGE_DIR/vocab.txt \
  --bert_config_file=$BERT_LARGE_DIR/bert_config.json \
  --init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt \
  --do_train=True \
  --train_file=$SQUAD_DIR/train-v2.0.json \
  --do_predict=True \
  --predict_file=$SQUAD_DIR/dev-v2.0.json \
  --train_batch_size=24 \
  --learning_rate=3e-5 \
  --num_train_epochs=2.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=gs://some_bucket/squad_large/ \
  --use_tpu=True \
  --tpu_name=$TPU_NAME \
  --version_2_with_negative=True

我试图将model.ckpt.meta,model.ckpt.index,model.ckpt.data复制到$ BERT_LARGE_DIR目录,并按如下所示更改run_squad.py标志以仅预测答案而不使用数据集进行训练:

python run_squad.py \
  --vocab_file=$BERT_LARGE_DIR/vocab.txt \
  --bert_config_file=$BERT_LARGE_DIR/bert_config.json \
  --init_checkpoint=$BERT_LARGE_DIR/model.ckpt \
  --do_train=False \
  --train_file=$SQUAD_DIR/train-v2.0.json \
  --do_predict=True \
  --predict_file=$SQUAD_DIR/dev-v2.0.json \
  --train_batch_size=24 \
  --learning_rate=3e-5 \
  --num_train_epochs=2.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=gs://some_bucket/squad_large/ \
  --use_tpu=True \
  --tpu_name=$TPU_NAME \
  --version_2_with_negative=True

它抛出bucket directory / model.ckpt不存在错误。

如何利用训练后生成的检查点并将其用于预测?

2 个答案:

答案 0 :(得分:1)

通常,训练时在const [searchValue, setSearchValue] = useState(''); const [data, setData] = useState(null); function handleInputChange({target: {value}}) { setSearchValue(value); } function fetchData(text) { const url = 'https://www.google.com/?search=' + text; axios.get(url).then(({data}) => setData({data})); } useEffect(() => { // don't run on componentDidMount or if string is empty if (searchValue) { fetchData(searchValue); } }, [searchValue]); 参数指定的目录中创建训练后的检查点。 (在您的情况下为gs:// some_bucket / squad_large /)。每个检查点都会有一个数字。您必须确定最大的数字。例如:--output_dir。现在,使用输出目录和最后保存的检查点(编号最大的模型)在评估/预测中设置model.ckpt-12345参数。 (对于您来说,它应该类似于--init_checkpoint

答案 1 :(得分:0)

在第二个代码中,标志/Users/xxx/projects/kubernetes-test/tasks应该是:

init_checkpoint

与上面的相同,而不是--init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt

如果问题仍然存在,您是否正在使用--init_checkpoint=$BERT_LARGE_DIR/model.ckpt预先训练的模型?

相关问题