如何理解使用tf.Print的tf.cond?

时间:2017-12-12 08:39:07

标签: tensorflow

查看代码:

import tensorflow as tf

x = tf.constant(1.0)
y = tf.constant(2.0)
z = tf.constant(3.0)
def f1():
    return tf.Print(x, [x])

def f2():
    return tf.Print(z, [z])
op = tf.cond(x>y, f1, f2)
with tf.Session() as sess:
    sess.run(op)

我很困惑,tf.Print的输出是3.0

如我们所知,tf.Print(z,[z])仅在评估z时输出z的值,但我认为我没有评估过z }。

另一个问题是关于tf.cond,它如何将节点添加到图表中,例如如何将tf.Print添加到图表中,我认为它应该将一些张量与tf.Print的返回相关联,否则tf.Print将不会被执行。

我很困惑。

1 个答案:

答案 0 :(得分:0)

我认为你可能会让tf.cond的参数的顺序混乱。电话:

tf.cond(predicate, f, g)

相当于"如果predicate为真,则评估f,否则评估g"

在您的示例中,由于您的谓词x > y为false,因此会对f2进行评估

注意

由于tensorflow 1.4,tf.cond将接受关键字参数true_fnfalse_fn,因此您可以通过编写来避免任何混淆:

tf.cond(predicate, true_fn=f, false_fn=g)

# Or equivalently...
tf.cond(predicate, false_fn=g, true_fn=f)