如何调试评估节点的原因?

时间:2017-01-14 13:49:06

标签: tensorflow

在Tensorflow中,只有在目标取决于它时,才能提供占位符:

x = tf.placeholder(tf.int32, [], "x")
y = 2 * x1
y = tf.Print(y, ["Computed y"])
z = 2 * y

# Error: should feed "x"
z.eval()

# OK, because y is not actually computed
z.eval({y: 1})

现在,在我更复杂的图表中,我遇到的问题是我得到了一些错误,一些占位符没有被输入,但我认为它们不应该被需要,通过上面说明的相同机制。

我该如何调试?错误消息仅说明需要哪个占位符,而不是原因。获取从占位符到目标的路径会很有帮助。

如何获取此信息?

1 个答案:

答案 0 :(得分:2)

如果图形不大,您可以从目标节点

进行反向图搜索

即,

def find(start, target):
    """Returns path to parent from given start node"""
    if start == target:
        return [target]
    for parent in start.op.inputs:
        found_path = find(parent, target)
        if found_path:
            return [start]+found_path
    return []

使用它

tf.reset_default_graph()
a1 = tf.ones(())
b1 = tf.ones(())
a2 = 2*a1
b2 = 2*b1
a3 = 2*a2
b3 = 2*b2
d4 = b3+a3
find(d4, a1)

应该返回

[<tf.Tensor 'add:0' shape=() dtype=float32>,
 <tf.Tensor 'mul_2:0' shape=() dtype=float32>,
 <tf.Tensor 'mul:0' shape=() dtype=float32>,
 <tf.Tensor 'ones:0' shape=() dtype=float32>]

如果图表很大,您可以通过将搜索限制在它们之间的操作来加快速度

import tensorflow.contrib.graph_editor as ge
ops_between = ge.get_walks_intersection_ops(source, target)

ge.get_walks_intersection_ops doc