TensorFlow:在MonitoredSession中恢复模型

时间:2017-12-21 21:09:02

标签: python tensorflow

我有一个包含多个变量的模型,包括一个全局步骤。我已经能够成功地使用MonitoredSession来保存每100个步骤的检查点和摘要。我期待MonitoredSession在会话以多次传递的方式运行时自动恢复所有变量(基于this文档),但这不会发生。如果我再次运行训练课程后查看全局步骤,我会发现它从零开始。这是我的代码的简化版本,没有实际的模型。如果需要更多代码来解决此问题,请告诉我

train_graph = tf.Graph()
with train_graph.as_default():
  # I create some datasets using the Dataset API
  # ...

  global_step = tf.train.create_global_step()

  # Create all the other variables and the model here
  # ...

  saver_hook = tf.train.CheckpointSaverHook(
      checkpoint_dir='checkpoint/',
      save_secs=None,
      save_steps=100,
      saver=tf.train.Saver(),
      checkpoint_basename='model.ckpt',
      scaffold=None)
  summary_hook = tf.train.SummarySaverHook(
      save_steps=100,
      save_secs=None,
      output_dir='summaries/',
      summary_writer=None,
      scaffold=None,
      summary_op=train_step_summary)
  num_steps_hook = tf.train.StopAtStepHook(num_steps=500) # Just for testing


  with tf.train.MonitoredSession(
      hooks=[saver_hook, summary_hook, num_steps_hook]) as sess:
    while not sess.should_stop():
      step = sess.run(global_step)
      if (step % 100 == 0):
        print(step)
      sess.run(optimizer)

当我第一次运行此代码时,我得到了这个输出

0
100
200
300
400

此时检查点文件夹的每百分之一检查点最多500个。如果我再次运行程序,我希望看到计数器从500开始,增加到900,但我只是得到相同的东西我的所有检查站都被覆盖了。有什么想法吗?

1 个答案:

答案 0 :(得分:0)

好吧,我明白了。它实际上非常简单。首先,使用MonitoredTraningSession()而不是MonitoredSession()更容易。此包装器会话将参数'checkpoint_dir'作为参数。我认为saver_hook会照顾恢复,但事实并非如此。为了解决我的问题,我只需更改我定义会话的行,如下所示:

with tf.train.MonitoredTrainingSession(hooks=[saver_hook, summary_hook], checkpoint_dir='checkpoint'):

也可以直接使用MonitoredSession完成,但您需要设置session_creator。