如何使用tf.train.Saver()方法recover_last_checkpoints?

时间:2017-07-13 10:03:26

标签: python tensorflow

文档写道应该将检查点路径列表传递给它,但如何获取列表?通过硬编码?不,这是一种愚蠢的做法。通过解析协议缓冲区文件(模型目录中名为checkpoint的文件)?但是tensorflow没有实现解析器,是吗?所以我必须自己实施一个吗? 您是否有良好的做法来获取检查点路径列表?

我提出这个问题,因为这些天我被一件事困扰。如您所知,由于某种原因,为期数天的培训可能会崩溃,我必须从最新的检查点恢复。恢复培训很简单,因为我只需要编写以下代码:

restorer = tf.train.Saver()
restorer.restore(sess, latest_checkpoint)

我可以硬编码latest_checkpoint,或者稍微更聪明,使用tf.train.latest_checkpoint()

然而,在我恢复训练后出现问题。那些在崩溃之前创建的旧检查点文件留在那里。 Saver仅管理在一次运行中创建的检查点文件。我希望它也可以管理以前创建的检查点文件,这样它们就会被自动删除,而且我不必每次都手动删除它们。我认为这种重复工作真的很傻。

然后我在类recover_last_checkpoints中找到tf.train.Saver()方法,它允许Saver管理旧检查点。但它使用起来并不方便。那么有什么好的解决方案吗?

1 个答案:

答案 0 :(得分:1)

正如@isarandi在评论中所提到的,最简单的方法是先使用get_checkpoint_state‌后跟all_model_checkpoi‌​nt_paths恢复所有检查点路径,这基本上是一个未记录的功能。然后,您可以恢复最新状态:

states = tf.train.get_checkpoint_state‌​(your_checkpoint_dir‌​)
checkpoint_paths = states.all_model_checkpoi‌​nt_paths
saver.recover_last_checkpoints(checkpoint_paths)