我想在训练期间每5个小时保存一次我的tensorflow变量。
因此,根据Tensorflow Saver doc,我用参数saver = Saver
构造了keep_checkpoint_every_n_hours=5
,并在学习的每一步都调用了saver.save()
。因此,我期望的是saver模块以某种方式检测开始训练后经过的时间,并每5个小时保存一次模型,而不是实际上在每次调用模型时都保存模型。
下面简化了我如何使用此功能。
sess = tf.Session()
model = Model(sess)
saver = tf.train.Saver(max_to_keep=5,
keep_checkpoint_every_n_hours=5)
step_count = 0
max_step = 10000
while step_count < max_step:
model.train()
saver.save(sess, 'model', global_step=step_count)
step_count += 1
但是,我发现通过这种方式,每次调用函数时都会保存模型。
我想我缺少了一些东西或没有以正确的方式使用它。 我想知道使用tensorflow Saver功能的正确方法。
谢谢。
答案 0 :(得分:1)
max_to_keep参数指定每次保存一个检查点,但一次最多保存5个检查点。
另一方面,keep_checkpoint_every_n_hours指定每N小时将保存一个检查点,并且不会将其删除或覆盖。
因此,我建议您使用keep_checkpoint_every_n_hours,以防您的培训花费很长时间并且可能会有所不同。因此,如果最后5个检查点变得同样无用,您可以还原到最多N小时之前的一个。