我使用MonitoredTrainingSession()
训练了一个模型,其中检查点保护钩tf.train.CheckpointSaverHook()
每1000步保存一次检查点。训练后,在检查点目录中创建了以下文件:
events.out.tfevents.1511969396.cmle-training-master-ef2237c814-0-xn7pp
graph.pbtxt
model.ckpt-1.meta
model.ckpt-1001.meta
model.ckpt-2001.meta
model.ckpt-3001.meta
model.ckpt-4001.meta
model.ckpt-4119.meta
我想恢复检查点,但不能,这是我的代码(假设上面的文件在checkpoints
目录中):
tf.train.import_meta_graph('checkpoints/model.ckpt-4139.meta')
saver = tf.train.Saver()
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state('./checkpoints/')
saver.restore(sess, ckpt.model_checkpoint_path)
问题是ckpt
是None
,我想我可能错过了一个文件......我做错了什么。
这是我保存检查点的方式:
hooks=lists()
hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir=checkpoint_dir, save_steps=checkpoint_iterations)
with tf.Graph().as_default():
with tf.device(tf.train.replica_device_setter()):
batch = model.input_fn(train_path, batch_size, epochs, 'train_queue')
tensors = model.model_fn(batch, content_weight, style_weight, tv_weight, vgg_path, style_features,
batch_size, learning_rate)
with tf.train.MonitoredTrainingSession(master=target,
is_chief=is_chief,
checkpoint_dir=job_dir,
hooks=hooks,
save_checkpoint_secs=None,
save_summaries_steps=None,
log_step_count_steps=10) as sess:
_ = sess.run(tensors)
(...)
答案 0 :(得分:2)
tf.train.get_checkpoint_state
检查您传递的目录中的checkpoint
(无扩展名)文件作为参数。
此文件通常包含类似于:
model_checkpoint_path: "model.ckpt-1"
all_model_checkpoint_paths: "model.ckpt-1"
如果缺少此文件,该函数将返回None
。
将具有该名称和内容的文本文件添加到模型文件夹中,您将能够使用已有的代码进行恢复。
非常重要的提示:要恢复这种方式,您需要所有检查点数据,即三个文件:.data-*
,.meta
和.index
。
但是,如果您只想恢复元图,则可以通过import_meta_graph()
中的详细信息the official TF guide进行恢复。
注意(来自import_meta_graph()
的定义):
此函数将MetaGraphDef协议缓冲区作为输入。如果 参数是一个包含MetaGraphDef协议缓冲区的文件 从文件内容构造协议缓冲区。那个功能呢 将graph_def字段中的所有节点添加到当前图形中, 重新创建所有集合,并返回从中构造的保护程序 saver_def字段。
除非您在同一目录中有.index
和.data-*
个文件,否则使用该保护程序将无效。