tf.scatter_update()如何在while_loop()内部工作

时间:2016-05-06 02:03:04

标签: tensorflow

我正在尝试使用Dashboard.const_get("Parent") == Dashboard::Parent # => true 更新tf.Variable内的tf.while_loop()。但是,结果是初始值而不是更新的值。以下是我要做的示例代码:

tf.scatter_update()

结果是:from __future__ import print_function import tensorflow as tf def cond(sequence_len, step): return tf.less(step,sequence_len) def body(sequence_len, step): begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0)) begin = tf.scatter_update(begin,1,step,use_locking=None) tf.get_variable_scope().reuse_variables() return (sequence_len, step+1) with tf.Graph().as_default(): sess = tf.Session() step = tf.constant(0) sequence_len = tf.constant(10) _,step, = tf.while_loop(cond, body, [sequence_len, step], parallel_iterations=10, back_prop=True, swap_memory=False, name=None) begin = tf.get_variable("begin",[3],dtype=tf.int32) init = tf.initialize_all_variables() sess.run(init) print(sess.run([begin,step])) 。但是,我认为结果应该是[array([0, 0, 0], dtype=int32), 10]。我在这里做错了吗?

1 个答案:

答案 0 :(得分:6)

这里的问题是循环体中没有任何东西取决于你的tf.scatter_update() op,所以它永远不会被执行。使其工作的最简单方法是在更新中添加对返回值的控制依赖性:

def body(sequence_len, step): 
    begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0))
    begin = tf.scatter_update(begin, 1, step, use_locking=None)
    tf.get_variable_scope().reuse_variables()

    with tf.control_dependencies([begin]):
        return (sequence_len, step+1)

请注意,此问题并非TensorFlow中的循环所特有。如果您刚刚定义了名为tf.scatter_update()的{​​{1}}操作,但在其上调用begin或依赖于它的内容,那么更新将不会发生。当你使用tf.while_loop()时,无法直接运行循环体中定义的操作,因此获得副作用的最简单方法是添加控件依赖项。

请注意,最终结果为sess.run():每次迭代都会将当前步骤分配给[0, 9, 0],而在最后一次迭代中,当前步骤的值为begin[1](条件为false)当9)。