查看代码:
import tensorflow as tf
x = tf.constant(1.0)
y = tf.constant(2.0)
z = tf.constant(3.0)
def f1():
return tf.Print(x, [x])
def f2():
return tf.Print(z, [z])
op = tf.cond(x>y, f1, f2)
with tf.Session() as sess:
sess.run(op)
我很困惑,tf.Print的输出是3.0
如我们所知,tf.Print(z,[z])仅在评估z
时输出z
的值,但我认为我没有评估过z
}。
另一个问题是关于tf.cond
,它如何将节点添加到图表中,例如如何将tf.Print
添加到图表中,我认为它应该将一些张量与tf.Print
的返回相关联,否则tf.Print
将不会被执行。
我很困惑。
答案 0 :(得分:0)
我认为你可能会让tf.cond
的参数的顺序混乱。电话:
tf.cond(predicate, f, g)
相当于"如果predicate
为真,则评估f
,否则评估g
"
在您的示例中,由于您的谓词x > y
为false,因此会对f2
进行评估
注意强>
由于tensorflow 1.4,tf.cond
将接受关键字参数true_fn
和false_fn
,因此您可以通过编写来避免任何混淆:
tf.cond(predicate, true_fn=f, false_fn=g)
# Or equivalently...
tf.cond(predicate, false_fn=g, true_fn=f)