结合条件和控制依赖

时间:2016-02-06 00:06:49

标签: tensorflow

我正在尝试执行一条条件代码,而这条代码依赖于另一个先执行的操作。这项工作的简单版本,如下所示:

x = tf.Variable(0.)
x_op = tf.assign(x, 1.)

with tf.control_dependencies([x_op]):
    true_fun  = lambda: tf.assign_add(x, 3.)
    false_fun = lambda: tf.constant([])
    pred = tf.constant(True)
    cond_op = control_flow_ops.cond(pred, true_fun, false_fun)

评估cond_op按预期将x设置为4.0的位置。但是,这个更复杂的版本不起作用:

def rest(x): tf.gather(x, tf.range(1, tf.size(x)))

x = tf.Variable([0., 1.])
x_op = tf.assign(x, [0., 1., 2.], validate_shape=False)

with tf.control_dependencies([x_op]):
    true_fun  = lambda: tf.assign(x, rest(x), validate_shape=False)
    false_fun = lambda: tf.constant([])
    pred = tf.constant(True)
    cond_op = control_flow_ops.cond(pred, true_fun, false_fun)

特别是x被分配[1.]而不是[1., 2.]。我想要的逻辑是首先为x分配[0., 1., 2.],然后将然后修剪为[1., 2.]。顺便提一下,这似乎与x更改的大小有关,因为如果在最初的x_op作业x中被分配[1., 2.]而不是[0., 1., 2.],然后评估cond_op结果x被分配[2.],这是正确的行为。即它首先更新为[1., 2.],然后修剪为[2.]

1 个答案:

答案 0 :(得分:5)

请注意,with tf.control_dependencies仅适用于在块内创建的操作。当您在阻止内部调用rest(x)时,您所指的x仍旧是x,它是tf.Variable函数的返回值,它只是{{1}保持变量的初始值。您可以通过调用Tensor来传递新值。这是完整的工作片段:

rest(x_op)