我正在尝试执行一条条件代码,而这条代码依赖于另一个先执行的操作。这项工作的简单版本,如下所示:
x = tf.Variable(0.)
x_op = tf.assign(x, 1.)
with tf.control_dependencies([x_op]):
true_fun = lambda: tf.assign_add(x, 3.)
false_fun = lambda: tf.constant([])
pred = tf.constant(True)
cond_op = control_flow_ops.cond(pred, true_fun, false_fun)
评估cond_op
按预期将x
设置为4.0
的位置。但是,这个更复杂的版本不起作用:
def rest(x): tf.gather(x, tf.range(1, tf.size(x)))
x = tf.Variable([0., 1.])
x_op = tf.assign(x, [0., 1., 2.], validate_shape=False)
with tf.control_dependencies([x_op]):
true_fun = lambda: tf.assign(x, rest(x), validate_shape=False)
false_fun = lambda: tf.constant([])
pred = tf.constant(True)
cond_op = control_flow_ops.cond(pred, true_fun, false_fun)
特别是x
被分配[1.]
而不是[1., 2.]
。我想要的逻辑是首先为x
分配[0., 1., 2.]
,然后将然后修剪为[1., 2.]
。顺便提一下,这似乎与x
更改的大小有关,因为如果在最初的x_op
作业x
中被分配[1., 2.]
而不是[0., 1., 2.]
,然后评估cond_op
结果x
被分配[2.]
,这是正确的行为。即它首先更新为[1., 2.]
,然后修剪为[2.]
。
答案 0 :(得分:5)
请注意,with tf.control_dependencies
仅适用于在块内创建的操作。当您在阻止内部调用rest(x)
时,您所指的x
仍旧是x
,它是tf.Variable
函数的返回值,它只是{{1}保持变量的初始值。您可以通过调用Tensor
来传递新值。这是完整的工作片段:
rest(x_op)