我有一个包含多个变量的模型,包括一个全局步骤。我已经能够成功地使用MonitoredSession来保存每100个步骤的检查点和摘要。我期待MonitoredSession在会话以多次传递的方式运行时自动恢复所有变量(基于this文档),但这不会发生。如果我再次运行训练课程后查看全局步骤,我会发现它从零开始。这是我的代码的简化版本,没有实际的模型。如果需要更多代码来解决此问题,请告诉我
train_graph = tf.Graph()
with train_graph.as_default():
# I create some datasets using the Dataset API
# ...
global_step = tf.train.create_global_step()
# Create all the other variables and the model here
# ...
saver_hook = tf.train.CheckpointSaverHook(
checkpoint_dir='checkpoint/',
save_secs=None,
save_steps=100,
saver=tf.train.Saver(),
checkpoint_basename='model.ckpt',
scaffold=None)
summary_hook = tf.train.SummarySaverHook(
save_steps=100,
save_secs=None,
output_dir='summaries/',
summary_writer=None,
scaffold=None,
summary_op=train_step_summary)
num_steps_hook = tf.train.StopAtStepHook(num_steps=500) # Just for testing
with tf.train.MonitoredSession(
hooks=[saver_hook, summary_hook, num_steps_hook]) as sess:
while not sess.should_stop():
step = sess.run(global_step)
if (step % 100 == 0):
print(step)
sess.run(optimizer)
当我第一次运行此代码时,我得到了这个输出
0
100
200
300
400
此时检查点文件夹的每百分之一检查点最多500个。如果我再次运行程序,我希望看到计数器从500开始,增加到900,但我只是得到相同的东西我的所有检查站都被覆盖了。有什么想法吗?
答案 0 :(得分:0)
好吧,我明白了。它实际上非常简单。首先,使用MonitoredTraningSession()而不是MonitoredSession()更容易。此包装器会话将参数'checkpoint_dir'作为参数。我认为saver_hook会照顾恢复,但事实并非如此。为了解决我的问题,我只需更改我定义会话的行,如下所示:
with tf.train.MonitoredTrainingSession(hooks=[saver_hook, summary_hook], checkpoint_dir='checkpoint'):
也可以直接使用MonitoredSession完成,但您需要设置session_creator。