在Tensorflow中,只有在目标取决于它时,才能提供占位符:
x = tf.placeholder(tf.int32, [], "x")
y = 2 * x1
y = tf.Print(y, ["Computed y"])
z = 2 * y
# Error: should feed "x"
z.eval()
# OK, because y is not actually computed
z.eval({y: 1})
现在,在我更复杂的图表中,我遇到的问题是我得到了一些错误,一些占位符没有被输入,但我认为它们不应该被需要,通过上面说明的相同机制。
我该如何调试?错误消息仅说明需要哪个占位符,而不是原因。获取从占位符到目标的路径会很有帮助。
如何获取此信息?
答案 0 :(得分:2)
如果图形不大,您可以从目标节点
进行反向图搜索即,
def find(start, target):
"""Returns path to parent from given start node"""
if start == target:
return [target]
for parent in start.op.inputs:
found_path = find(parent, target)
if found_path:
return [start]+found_path
return []
使用它
tf.reset_default_graph()
a1 = tf.ones(())
b1 = tf.ones(())
a2 = 2*a1
b2 = 2*b1
a3 = 2*a2
b3 = 2*b2
d4 = b3+a3
find(d4, a1)
应该返回
[<tf.Tensor 'add:0' shape=() dtype=float32>,
<tf.Tensor 'mul_2:0' shape=() dtype=float32>,
<tf.Tensor 'mul:0' shape=() dtype=float32>,
<tf.Tensor 'ones:0' shape=() dtype=float32>]
如果图表很大,您可以通过将搜索限制在它们之间的操作来加快速度
import tensorflow.contrib.graph_editor as ge
ops_between = ge.get_walks_intersection_ops(source, target)