保存检查点并在tensorflow中恢复训练

时间:2016-04-19 23:15:49

标签: python-2.7 tensorflow

我正在玩保存检查点并从已保存的检查点恢复训练。我正在关注 - https://www.tensorflow.org/versions/r0.8/api_docs/python/train.html#import_meta_graph中给出的示例 为了简单起见,我没有使用任何真实的'培训网络。我刚刚执行了一个简单的减法操作,每个检查点一次又一次地在相同的张量上保存相同的操作。 以下ipython笔记本的形式提供了一个最小的示例 - https://gist.github.com/dasabir/29b8f84c6e5e817a72ce06584e988f10

在第一阶段,我运行循环100次(通过设置变量的值' endIter = 100'在代码中)并每隔10次迭代保存检查点。因此,保存的检查点编号为 - 9,19,...,99。现在,当我更改' enditer'值得说200并且恢复训练,检查站再次开始从9,19,...(不是109,119,129,......)保存。有没有我失踪的伎俩?

1 个答案:

答案 0 :(得分:4)

你可以打印出#latest; latest_ckpt',看看它是否指向最新的ckpt文件?此外,您需要使用tf.variable:

维护global_step
global_step = tf.Variable(0, name='global_step', trainable=False)
...
ckpt = tf.train.get_checkpoint_state(ckpt_dir)
if ckpt and ckpt.model_checkpoint_path:
    print ckpt.model_checkpoint_path
    saver.restore(sess, ckpt.model_checkpoint_path) # restore all variables
start = global_step.eval() # get last global_step
print "Start from:", start

for i in range(start, 100):
...
    global_step.assign(i).eval() # set and update(eval) global_step with index, i
    saver.save(sess, ckpt_dir + "/model.ckpt", global_step=global_step)

您可以查看完整示例:

https://github.com/nlintz/TensorFlow-Tutorials/pull/32/files