我正在尝试按照运行时统计信息指令here获取我的张量流代码配置文件(网络中每个层的运行和内存消耗)。据我所知,我需要创建运行选项并像这样运行元数据
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
并将其传递给sess.run
然而,由于我也在尝试使用tf.train.MonitoredTrainingSession
,我不知道是否可以将同样的内容传递给此类。一种似是而非的方法可以使用Hooks,但我不知道该怎么做。我对他们来说还是个新手
答案 0 :(得分:3)
您只需创建一个自定义挂钩并将其传递给tf.RunMetadata()
即可。无需将自己的import tensorflow as tf
class TraceHook(tf.train.SessionRunHook):
"""Hook to perform Traces every N steps."""
def __init__(self, ckptdir, every_step=50, trace_level=tf.RunOptions.FULL_TRACE):
self._trace = every_step == 1
self.writer = tf.summary.FileWriter(ckptdir)
self.trace_level = trace_level
self.every_step = every_step
def begin(self):
self._global_step_tensor = tf.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError("Global step should be created to use _TraceHook.")
def before_run(self, run_context):
if self._trace:
options = tf.RunOptions(trace_level=self.trace_level)
else:
options = None
return tf.train.SessionRunArgs(fetches=self._global_step_tensor,
options=options)
def after_run(self, run_context, run_values):
global_step = run_values.results - 1
if self._trace:
self._trace = False
self.writer.add_run_metadata(run_values.run_metadata,
f'{global_step}', global_step)
if not (global_step + 1) % self.every_step:
self._trace = True
实例传递给运行调用。
以下是一个示例Hook,它将每N个步骤的元数据存储到ckptdir:
before_run
检入after_run
是否必须跟踪,如果是,则添加RunOptions。在_trace
中,它检查是否需要跟踪下一个运行调用,如果需要,它会再次将$output .= "<div class=\"fashmag-gallery-half\" data-slick=\"{\"slidesToShow\": 4}\">";
设置为True。此外,它还会在元数据可用时存储元数据。