`MonitoredTrainingSession()`如何与“恢复”和“测试模式”一起工作?

时间:2017-03-29 22:04:38

标签: python session tensorflow distributed restore

在Tensorflow中,我们可以使用Between-graph Replication构建和创建多个Tensorflow会话,以进行分布式培训。 MonitoredTrainingSession()协调多个Tensorflow会话,checkpoint_dir有一个参数MonitoredTrainingSession()来恢复Tensorflow会话/图表。现在我有以下问题:

  1. 我们通常使用tf.train.Saver()的对象按saver.restore(...)恢复Tensorflow图。但是,我们如何使用MonitoredTrainingSession()
  2. 恢复它们
  3. 由于我们运行多个流程并且每个流程都构建并创建了一个Tensorflow会话用于培训,我想知道我们是否还必须在培训后运行多个流程进行测试(或预测)。换句话说,MonitoredTrainingSession()如何使用测试(或预测)模式?
  4. 我阅读了Tensorflow Doc,但未找到这两个问题的答案。如果有人有解决方案我真的很感激。谢谢!

3 个答案:

答案 0 :(得分:3)

简短回答:

  1. 您需要将全局步骤传递给传递给mon_sess.run的优化程序。这样就可以保存和检索已保存的检查点。
  2. 可以通过单个MonitoredTrainingSession同时运行培训+交叉验证会话。首先,您需要通过相同图形的单独流传递训练批次和交叉验证批次(我建议您查找this guide以获取有关如何执行此操作的信息)。其次,您必须 - 到mon_sess.run() - 传递训练流的优化器,以及交叉验证流的损失(/您想要跟踪的参数)的参数。如果要与训练分开运行测试会话,只需通过图形运行测试集,并仅通过图形运行test_loss(/您想要跟踪的其他参数)。有关如何完成此操作的详细信息,请参见下文。
  3. 答案很长:

    我将更新我的答案,因为我自己可以更好地了解tf.train.MonitoredSession可以做些什么(tf.train.MonitoredTrainingSession只是创建一个专门版本的tf.train.MonitoredSession,可以看到在source code)。

    以下示例代码显示了如何每5秒将检查点保存到' ./ ckpt_dir'。中断后,它将在上次保存的检查点重新启动:

    def train(inputs, labels_onehot, global_step):
        out = tf.contrib.layers.fully_connected(
                                inputs,
                                num_outputs=10,
                                activation_fn=tf.nn.sigmoid)
        loss = tf.reduce_mean(
                 tf.reduce_sum(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                                logits=out,
                                labels=labels_onehot), axis=1))
        train_op = opt.minimize(loss, global_step=global_step)
        return train_op
    
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()
        inputs = ...
        labels_onehot = ...
        train_op = train(inputs, labels_onehot, global_step)
    
        with tf.train.MonitoredTrainingSession(
            checkpoint_dir='./ckpt_dir',
            save_checkpoint_secs=5,
            hooks=[ ... ] # Choose your hooks
        ) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
    

    为了实现这一点,MonitoredTrainingSession中发生的事情实际上是三件事:

    1. tf.train.MonitoredTrainingSession创建一个tf.train.Scaffold对象,其作用类似于网络中的蜘蛛;它收集你训练,保存和加载模型所需的部分。
    2. 它会创建一个tf.train.ChiefSessionCreator对象。我对这个的了解是有限的,但根据我对它的理解,它用于当你的tf算法分布在多个服务器上时。我的看法是它告诉运行该文件的计算机它是主计算机,并且在这里应该保存检查点目录,并且记录器应该在这里记录它们的数据等。
    3. 它会创建一个tf.train.CheckpointSaverHook,用于保存检查点。
    4. 为了使其工作,必须将tf.train.CheckpointSaverHook和tf.train.ChiefSessionCreator传递给checkpoint目录和scaffold的相同引用。如果上面的例子中的tf.train.MonitoredTrainingSession及其参数是用上面的3个组件实现的,那么它看起来像这样:

      checkpoint_dir = './ckpt_dir'
      
      scaffold = tf.train.Scaffold()
      saverhook = tf.train.CheckpointSaverHook(
          checkpoint_dir=checkpoint_dir,
          save_secs=5
          scaffold=scaffold
      )
      session_creator = tf.train.ChiefSessionCreator(
          scaffold=scaffold,
          checkpoint_dir=checkpoint_dir
      )
      
      with tf.train.MonitoredSession(
          session_creator=session_creator,
          hooks=[saverhook]) as mon_sess:
              while not mon_sess.should_stop():
                  mon_sess.run(train_op)
      

      为了进行训练+交叉验证会话,您可以将tf.train.MonitoredSession.run_step_fn()与partial一起使用,该部分运行会话调用而不调用任何挂钩。这看起来是你为 n 迭代训练你的模型,然后运行你的测试集,重新初始化你的迭代器并返回训练你的模型等等。当然,你必须设置你的要执行此操作时要重用的变量= tf.AUTO_REUSE。在代码中执行此操作的方法如下所示:

      from functools import partial
      
      # Build model
      ...
      
      with tf.variable_scope(..., reuse=tf.AUTO_REUSE):
          ...
      
      ...
      
      def step_fn(fetches, feed_dict, step_context):
          return step_context.session.run(fetches=fetches, feed_dict=feed_dict)
      
      with tf.train.MonitoredTrainingSession(
                      checkpoint_dir=...,
                      save_checkpoint_steps=...,
                      hooks=[...],
                      ...
                      ) as mon_sess:
      
                      # Initialize iterators (assuming tf.Databases are used)
                      mon_sess.run_step_fn(
                                 partial(
                                     step_fn, 
                                     [train_it.initializer, 
                                      test_it.initializer, 
                                      ...
                                     ], 
                                     {}
                                  )
                      )
      
                      while not mon_sess.should_stop():
                          # Train session
                          for i in range(n):
                              try:
                                  train_results = mon_sess.run(<train_fetches>)
                              except Exception as e:
                                  break
      
                          # Test session
                          while True:
                              try:
                                  test_results = mon_sess.run(<test_fetches>)
                              except Exception as e:
                                  break
      
                          # Reinitialize parameters
                          mon_sess.run_step_fn(
                                     partial(
                                        step_fn, 
                                        [train_it.initializer, 
                                         test_it.initializer, 
                                         ...
                                        ], 
                                        {}
                                     )
                          )
      

      partial函数只是在step_fn上执行currying(函数式编程中的经典函数),在mon_sess.run_step_fn()中使用。上面的整个代码尚未经过测试,您可能必须在开始测试会话之前重新初始化train_it,但希望现在很清楚如何在同一个运行中运行训练集和验证集。如果你想在同一个图中同时绘制训练曲线和测试曲线,这可以与张量板的custom_scalar tool一起使用。

      最后,这是我能够做到的这个功能的最佳实现,我个人希望tensorflow将来更容易实现这个功能,因为它非常繁琐,可能效率不高。我知道有Estimator这样的工具可以运行train_and_evaluate函数,但是由于这会在每次列车验证和交叉验证运行之间重建图形,如果你只运行一个,它的效率非常低单台电脑。我在某处看到Keras + tf具有此功能,但由于我不使用Keras + tf,因此这不是一个选项。无论如何,我希望这可以帮助其他人在那里挣扎同样的事情!

答案 1 :(得分:0)

您应该导入元图,然后恢复模型。 从下面的片段中获取灵感,应该适合你。

    self.sess = tf.Session()
    ckpt = tf.train.latest_checkpoint("location-of/model")
    saver = tf.train.import_meta_graph(ckpt + '.meta', clear_devices=True)
    saver.restore(self.sess, ckpt)

答案 2 :(得分:-1)

  1. 似乎正在为您处理恢复。在API docs中,它说调用MonitoredTrainingSession会返回一个MonitoredSession实例,在创建时“......如果检查点存在则恢复变量......”

  2. 查看tf.contrib.learn.Estimator(..).predict(..)以及更具体的tf.contrib.learn.Estimator(..)._infer_model(..)方法herehere。他们还在那里创建了一个MonitoredSession。