如何检索提取的执行顺序?

时间:2015-11-21 18:42:36

标签: python tensorflow

给定一组提取,如何检索将在session.run(fetches)的单个调用中执行的(可能是非唯一的)提取顺序?

1 个答案:

答案 0 :(得分:4)

一个合理的解决方案是重新计算python中的拓扑排序。看起来C ++实现没有在python API中公开。如果有些情况不起作用,请告诉我。

以下是一个例子:

import tensorflow as tf
from toposort import toposort


sess = tf.InteractiveSession()

matrix1=tf.constant([[3., 3.]])
matrix2=tf.constant([[2.], [2.]])

sum = tf.add(matrix1, matrix2)
product = tf.matmul(matrix1, matrix2)
final = tf.mul(sum, product)

g = sess.graph
deps = {}

for op in g.get_operations():
    # op node
    op_inputs = set()
    op_inputs.update([t.name for t in op.inputs])
    deps[op.name] = op_inputs

    # tensor output node
    for t in op.outputs:
        deps[t.name]={op.name}
deps
{u'Add': {u'Const:0', u'Const_1:0'},
 u'Add:0': {u'Add'},
 u'Const': set(),
 u'Const:0': {u'Const'},
 u'Const_1': set(),
 u'Const_1:0': {u'Const_1'},
 u'MatMul': {u'Const:0', u'Const_1:0'},
 u'MatMul:0': {u'MatMul'},
 u'Mul': {u'Add:0', u'MatMul:0'},
 u'Mul:0': {u'Mul'}}

list(toposort(deps))
[{u'Const', u'Const_1'},
 {u'Const:0', u'Const_1:0'},
 {u'Add', u'MatMul'},
 {u'Add:0', u'MatMul:0'},
 {u'Mul'},
 {u'Mul:0'}]

随后,我们可以手动逐步评估图表中的每个节点 - 后续调用Session.run()涉及传入feed_dict,累积所有先前输入的结果。这是非常缓慢的,因为C ++和numpy数据之间不断的混乱,以及内存密集,因为我们正在缓存所有内容的输出值。