是否有一种简单的方法可以获取TensorFlow中的输入/输出张量的时间估计,总计数或内存估计等统计数据/估计值,例如通过python API?
答案 0 :(得分:0)
要做的第一件事是启用成本模型的收集:
# Collect and aggregate statistics every 50 iterations
options = tf.GraphOptions(build_cost_model=50)
cfg = tf.ConfigProto(graph_options=options)
sess = tf.Session(config=cfg)
然后,您可以生成成本模型的更新版本,如下所示:
metadata = tf.RunMetadata()
# This is optional, but will generally give you more accurate statistics,
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
for step in xrange(0, 1000):
_ = sess.run([train_op], options=run_options, run_metadata=metadata)
if len(metadata.cost_graph.node) > 0:
print ("HERE IS THE COST GRAPH " + str(metadata.cost_graph))