Tensorflow while_loop如何返回所有中间值?

时间:2018-12-12 08:17:47

标签: python tensorflow

从tensorflow文档中,我们可以轻松获取循环的最后一个值:

https://www.tensorflow.org/api_docs/python/tf/while_loop

sess = tf.InteractiveSession()
i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])
print(r.eval())

结果是

10

什么是获取循环中间值的好方法:

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

2 个答案:

答案 0 :(得分:1)

您可以将这些中间值存储在张量中:

import tensorflow as tf

sess = tf.InteractiveSession()
size = 10
values = tf.zeros(size, dtype=tf.int32)
i = tf.constant(0)
c = lambda i, v: tf.less(i, size)
b = lambda i, v: [tf.add(i, 1), v + tf.one_hot(i, size, on_value=i)]
i_, values_ = tf.while_loop(c, b, [i, values])
print(i_.eval())
print(values_.eval())

答案 1 :(得分:1)

使用tf.TensorArray:

sess = tf.InteractiveSession()
N = 10
c = lambda i, _: tf.less_equal(i, N)
b = lambda i, ta: [tf.add(i, 1), ta.write(i, i)]
_, ta = tf.while_loop(c, b, [0, tf.TensorArray(tf.int32, size=N+1)])
print(ta.stack().eval())

或tf.scan(使用tf.while_loop和tf.TensorArray实现):

sess = tf.InteractiveSession()
N = 10
fn = lambda a, x: x
elems = tf.constant(list(range(N+1)))
r = tf.scan(fn, elems)
print(r.eval())