在TensorFlow中获取CostModel统计数据/估算值?

时间:2016-10-24 09:53:58

标签: python tensorflow

是否有一种简单的方法可以获取TensorFlow中的输入/输出张量的时间估计,总计数或内存估计等统计数据/估计值,例如通过python API?

1 个答案:

答案 0 :(得分:0)

要做的第一件事是启用成本模型的收集:

# Collect and aggregate statistics every 50 iterations
options = tf.GraphOptions(build_cost_model=50)
cfg = tf.ConfigProto(graph_options=options)
sess = tf.Session(config=cfg)

然后,您可以生成成本模型的更新版本,如下所示:

metadata = tf.RunMetadata()
# This is optional, but will generally give you more accurate statistics,
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)

for step in xrange(0, 1000):
  _ = sess.run([train_op], options=run_options, run_metadata=metadata)

  if len(metadata.cost_graph.node) > 0:
    print ("HERE IS THE COST GRAPH " + str(metadata.cost_graph))