使用保护程序参数`keep_checkpoint_every_n_hours'的正确方法是什么?

时间:2019-05-13 05:54:35

标签: python tensorflow

我想在训练期间每5个小时保存一次我的tensorflow变量。

因此,根据Tensorflow Saver doc,我用参数saver = Saver构造了keep_checkpoint_every_n_hours=5,并在学习的每一步都调用了saver.save()。因此,我期望的是saver模块以某种方式检测开始训练后经过的时间,并每5个小时保存一次模型,而不是实际上在每次调用模型时都保存模型。

下面简化了我如何使用此功能。

sess = tf.Session()
model = Model(sess)
saver = tf.train.Saver(max_to_keep=5,
                       keep_checkpoint_every_n_hours=5)

step_count = 0
max_step = 10000

while step_count < max_step:
    model.train()
    saver.save(sess, 'model', global_step=step_count)
    step_count += 1

但是,我发现通过这种方式,每次调用函数时都会保存模型。

我想我缺少了一些东西或没有以正确的方式使用它。 我想知道使用te​​nsorflow Saver功能的正确方法。

谢谢。

1 个答案:

答案 0 :(得分:1)

max_to_keep参数指定每次保存一个检查点,但一次最多保存5个检查点。

另一方面,

keep_checkpoint_every_n_hours指定每N小时将保存一个检查点,并且不会将其删除或覆盖。

因此,我建议您使用keep_checkpoint_every_n_hours,以防您的培训花费很长时间并且可能会有所不同。因此,如果最后5个检查点变得同样无用,您可以还原到最多N小时之前的一个。