TensorFlow dequeue within while_loop

Time: 2017-04-13 14:55:00

Tags: while-loop tensorflow

Consider this simple example:

l = [1,2,324,3,12,1,2,3]
q = tf.train.input_producer(l, shuffle=False)
x = q.dequeue()

t = tf.TensorArray(dtype=tf.int32, size=5, dynamic_size=True, clear_after_read=True)

_, t = tf.while_loop(cond=lambda i, a: tf.less(i, 5, name='less_op'),
                     body=lambda i, a: [i+1, a.write(i, [x])],
                     loop_vars=[0, t])

It outputs [1 1 1 1 1], because dequeue() is called only once. How can I trigger the dequeue operation on each iteration?

Thank you!

Cheers, Kris

1 answer:

Answer 0 (score: 1)

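The problem arises because the body of tf.while_loop() captures the tensor x as a loop invariant, whereas you want the side effect of the dequeue to be executed inside the loop on every iteration.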
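The same distinction can be seen in a plain-Python analogy, with no TensorFlow involved. This is only a rough sketch of the idea, not of how TensorFlow builds or runs its graph: the function names `captured_once` and `dequeued_each_iteration` are hypothetical, and `itertools.cycle` is an assumed stand-in for a queue that keeps cycling through the input list.

```python
from itertools import cycle


def captured_once(values, n):
    """Analogy for the question's code: take one value, then reuse it."""
    source = cycle(values)
    x = next(source)  # evaluated a single time, outside the loop
    return [x for _ in range(n)]


def dequeued_each_iteration(values, n):
    """Analogy for the fix: take a fresh value inside the loop body."""
    source = cycle(values)
    return [next(source) for _ in range(n)]


data = [1, 2, 324, 3, 12, 1, 2, 3]
print(captured_once(data, 5))            # [1, 1, 1, 1, 1]
print(dequeued_each_iteration(data, 5))  # [1, 2, 324, 3, 12]
```

In the first function the value is read before the loop starts, so every iteration sees the same element; in the second the read happens once per iteration.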

The solution is to move the call to q.dequeue() inside the loop body, as follows:

import tensorflow as tf

l = [1, 2, 324, 3, 12, 1, 2, 3]
q = tf.train.input_producer(l, shuffle=False)
t = tf.TensorArray(dtype=tf.int32, size=5, dynamic_size=True, clear_after_read=True)

# N.B. set `parallel_iterations=1` to ensure that values are dequeued in a
# deterministic FIFO order.
_, t = tf.while_loop(cond=lambda i, a: tf.less(i, 5, name='less_op'),
                     body=lambda i, a: [i+1, a.write(i, [q.dequeue()])],
                     loop_vars=[0, t],
                     parallel_iterations=1)

result = t.stack()

sess = tf.Session()
tf.train.start_queue_runners(sess)
print(sess.run(result))  # ==> '[[1], [2], [324], [3], [12]]'
print(sess.run(result))  # ==> '[[1], [2], [3], [1], [2]]'