有没有tf.train.SessionRunHook的教程?

时间:2017-08-06 13:11:01

标签: tensorflow

阅读API DOC后,我也无法理解SessionRunHook的用法。例如,SessionRunHook的成员序列是什么 函数被调用?是after_create_session -> before_run -> begin -> after_run -> end吗? 我无法通过详细的例子找到教程,是否有更详细的解释?

1 个答案:

答案 0 :(得分:24)

你可以找一个教程here,但有一点时间你可以跳过构建网络的一部分。或者你可以根据我的经验阅读下面的小摘要。

首先,应使用MonitoredSession代替普通Session

  

SessionRunHook扩展session.run()的{​​{1}}次调用。

然后可以找到一些常见的MonitoredSessionhere。一个简单的问题是SessionRunHook,但您可能希望在导入后添加以下行,以便在运行时查看日志:

LoggingTensorHook

或者您可以选择实施自己的tf.logging.set_verbosity(tf.logging.INFO) 课程。一个简单的来自cifar10 tutorial

SessionRunHook

其中class _LoggerHook(tf.train.SessionRunHook): """Logs loss and runtime.""" def begin(self): self._step = -1 self._start_time = time.time() def before_run(self, run_context): self._step += 1 return tf.train.SessionRunArgs(loss) # Asks for loss value. def after_run(self, run_context, run_values): if self._step % FLAGS.log_frequency == 0: current_time = time.time() duration = current_time - self._start_time self._start_time = current_time loss_value = run_values.results examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration sec_per_batch = float(duration / FLAGS.log_frequency) format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') print (format_str % (datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch)) 在课外定义。在loss使用_LoggerHook时,print使用LoggingTensorHook打印信息。

最后,为了更好地理解它的工作原理,执行顺序由伪代码tf.logging.INFO here表示:

MonitoredSession

希望这有帮助。