无法从张量流检查点读取以进行微调

时间:2019-08-19 18:36:23

标签: python python-3.x tensorflow nlp

我正在尝试使用预训练的BERT模型对SST2数据处理器进行微调。但是当我给出预训练模型的检查点时,它表明“在检查点中找不到关键的output_bias”。

我认为这可能是由于预训练的BERT模型检查点中的错误所致。因此,我再次进行了预培训。但是,我仍然面临着同样的问题。

TASK = 'STS' #@param {type:\"string\"}
TASK_DATA_DIR = 'glue_data/STS-B/'# + TASK

output_dir = 'trained_model/observation'
tf.gfile.MakeDirs(output_dir)

BERT_MODEL = path + 'multi_cased_L-12_H-768_A-12/' 
VOCAB_FILE = os.path.join(BERT_MODEL, 'vocab.txt')   
CONFIG_FILE = os.path.join(BERT_MODEL, 'bert_config.json')   
INIT_CHECKPOINT = os.path.join(BERT_MODEL, 'bert_model.ckpt')   
DO_LOWER_CASE = BERT_MODEL.startswith('cased')

tokenizer = tokenization.FullTokenizer(vocab_file=VOCAB_FILE, 
do_lower_case=DO_LOWER_CASE)

TRAIN_BATCH_SIZE = 1   
EVAL_BATCH_SIZE = 8   
PREDICT_BATCH_SIZE = 8   
LEARNING_RATE = 2e-5   
NUM_TRAIN_EPOCHS = 3.0   
MAX_SEQ_LENGTH = 128   

processors = {   
    "sts": run_classifier.StsProcessor,    
}   

processor = processors[TASK.lower()]()    
label_list = processor.get_labels()   

错误是:

  

NotFoundError:从检查点还原失败。这很可能   由于变量名称或其他图形键从   检查点。请确保您没有更改预期的图表   基于检查点。原始错误:找不到键output_bias   在检查点[[node save / RestoreV2(在   /home/subraas3/.conda/envs/tensorflow_13/lib/python3.7/   site-packages / tensorflow_estimator / python / estimator / estimator.py:1403)   ]] [[节点保存/还原V2(在   /home/subraas3/.conda/envs/tensorflow_13/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py:1403)   ]]

1 个答案:

答案 0 :(得分:0)

如错误消息中所指出,如果

  1. tf计算图中的图层已重命名。即预训练检查点中的图层名称与提供的API中的名称不同
  2. 将新层添加到API,即更改网络体系结构。或者,
  3. 已删除(不太可能)存在于预先训练的检查点中的图层。

请检查bert API的版本是否与预训练的检查点版本相同。如果它们相同,则可能需要使用此tool.

来手动检查检查点中的tf图是否与API一致