对`tf.cond`的行为感到困惑

时间:2016-05-06 03:54:53

标签: tensorflow

我的图表中需要一个条件控制流程。如果predTrue,则图形应调用更新变量的op,然后返回该变量,否则返回变量不变。简化版本是:

pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])
def update_x_2():
  with tf.control_dependencies([assign_x_2]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

但是,我发现pred=Truepred=False会导致相同的结果y=[2],这意味着当update_x_2未被tf.cond选中时,也会调用分配操作{1}}。怎么解释这个?以及如何解决这个问题?

2 个答案:

答案 0 :(得分:33)

TL; DR:如果您希望tf.cond()在其中一个分支中执行副作用(如作业),则必须创建执行副作用的操作内部您传递给tf.cond()的函数。

tf.cond()的行为有点不直观。由于TensorFlow图中的执行向前流过图,因此在 分支中引用的所有操作必须在评估条件之前执行。这意味着true和false分支都会在tf.assign() op上获得控制依赖关系,因此y始终设置为2,即使pred is为False。

解决方案是在定义true分支的函数内创建tf.assign() op。例如,您可以按如下方式构建代码:

pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
def update_x_2():
  with tf.control_dependencies([tf.assign(x, [2])]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval(feed_dict={pred: False}))  # ==> [1]
  print(y.eval(feed_dict={pred: True}))   # ==> [2]

答案 1 :(得分:3)

pred = tf.constant(False)
x = tf.Variable([1])

def update_x_2():
    assign_x_2 = tf.assign(x, [2])
    with tf.control_dependencies([assign_x_2]):
        return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

这将获得[1]的结果。

这个答案与上面的答案完全相同。但我想分享的是你可以把你想要使用的每个操作放在它的分支函数中。因为,鉴于您的示例代码,x函数可以直接使用张量update_x_2