阅读API DOC后,我也无法理解SessionRunHook的用法。例如,SessionRunHook的成员序列是什么
函数被调用?是after_create_session -> before_run -> begin -> after_run -> end
吗?
我无法通过详细的例子找到教程,是否有更详细的解释?
答案 0 :(得分:24)
你可以找一个教程here,但有一点时间你可以跳过构建网络的一部分。或者你可以根据我的经验阅读下面的小摘要。
首先,应使用MonitoredSession
代替普通Session
。
SessionRunHook扩展
session.run()
的{{1}}次调用。
然后可以找到一些常见的MonitoredSession
类here。一个简单的问题是SessionRunHook
,但您可能希望在导入后添加以下行,以便在运行时查看日志:
LoggingTensorHook
或者您可以选择实施自己的tf.logging.set_verbosity(tf.logging.INFO)
课程。一个简单的来自cifar10 tutorial
SessionRunHook
其中class _LoggerHook(tf.train.SessionRunHook):
"""Logs loss and runtime."""
def begin(self):
self._step = -1
self._start_time = time.time()
def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(loss) # Asks for loss value.
def after_run(self, run_context, run_values):
if self._step % FLAGS.log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time
self._start_time = current_time
loss_value = run_values.results
examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
sec_per_batch = float(duration / FLAGS.log_frequency)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print (format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
在课外定义。在loss
使用_LoggerHook
时,print
使用LoggingTensorHook
打印信息。
最后,为了更好地理解它的工作原理,执行顺序由伪代码tf.logging.INFO
here表示:
MonitoredSession
希望这有帮助。