使用tensorflow中的monitoredtrainingsession获取运行时统计信息

时间:2017-07-24 21:22:07

标签: tensorflow statistics runtime profile

我正在尝试按照运行时统计信息指令here获取我的张量流代码配置文件(网络中每个层的运行和内存消耗)。据我所知,我需要创建运行选项并像这样运行元数据

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()

并将其传递给sess.run

然而,由于我也在尝试使用tf.train.MonitoredTrainingSession,我不知道是否可以将同样的内容传递给此类。一种似是而非的方法可以使用Hooks,但我不知道该怎么做。我对他们来说还是个新手

1 个答案:

答案 0 :(得分:3)

您只需创建一个自定义挂钩并将其传递给tf.RunMetadata()即可。无需将自己的import tensorflow as tf class TraceHook(tf.train.SessionRunHook): """Hook to perform Traces every N steps.""" def __init__(self, ckptdir, every_step=50, trace_level=tf.RunOptions.FULL_TRACE): self._trace = every_step == 1 self.writer = tf.summary.FileWriter(ckptdir) self.trace_level = trace_level self.every_step = every_step def begin(self): self._global_step_tensor = tf.train.get_global_step() if self._global_step_tensor is None: raise RuntimeError("Global step should be created to use _TraceHook.") def before_run(self, run_context): if self._trace: options = tf.RunOptions(trace_level=self.trace_level) else: options = None return tf.train.SessionRunArgs(fetches=self._global_step_tensor, options=options) def after_run(self, run_context, run_values): global_step = run_values.results - 1 if self._trace: self._trace = False self.writer.add_run_metadata(run_values.run_metadata, f'{global_step}', global_step) if not (global_step + 1) % self.every_step: self._trace = True 实例传递给运行调用。

以下是一个示例Hook,它将每N个步骤的元数据存储到ckptdir:

before_run

检入after_run是否必须跟踪,如果是,则添加RunOptions。在_trace中,它检查是否需要跟踪下一个运行调用,如果需要,它会再次将$output .= "<div class=\"fashmag-gallery-half\" data-slick=\"{\"slidesToShow\": 4}\">"; 设置为True。此外,它还会在元数据可用时存储元数据。