如何为tf.data.Dataset创建性能时间表?

时间:2018-01-13 17:40:50

标签: tensorflow tensorflow-datasets

我试图在Tensorflow Dataset pipline中查看需要这么长时间的内容,不幸的是,当我运行分析时,我的数据集的整个执行都被一个操作所覆盖:" IteratorGetNext"。有没有办法窥视数据集图形内部以分别查看每个地图?

这是一个简单的例子,可以通过添加num_parallel_calls来更快地运行,但遗憾的是,当整个操作出现时,人们无法从时间线告诉它(见截图)

import tensorflow as tf
from tensorflow.python.ops import io_ops
from tensorflow.contrib.framework.python.ops import audio_ops

g = tf.Graph()
with g.as_default():
  ds = tf.data.Dataset.list_files("work/input/train/audio/**/*.wav")
  ds = (ds
        .map(lambda x: io_ops.read_file(x))
        .map(lambda x: audio_ops.decode_wav(x,
                                 desired_channels=1,
                                 desired_samples=16000))
        .batch(30*1000)
        .prefetch(2)
  )

  iterator = ds.make_one_shot_iterator()
  get_next = iterator.get_next()


run_metadata = tf.RunMetadata()
run_config = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)


with tf.Session(graph=g) as sess:
  sess.run(get_next,
           options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
           run_metadata=run_metadata)

from tensorflow.python.client import timeline
trace = timeline.Timeline(step_stats=run_metadata.step_stats)

trace_file = open('timelines/example.json', 'w')
trace_file.write(trace.generate_chrome_trace_format())
trace_file.close()

enter image description here

0 个答案:

没有答案