在TensorFlow中,我有一个tf.while_loop
,其中涉及到TensorArray
的使用。我写了一个最小的玩具示例来演示我遇到的一个问题。
对于每个循环,我想读取此数组中一个元素的值,将其添加到张量,然后将结果分配给数组的另一个元素。 while循环的body
参数定义为以下函数:
def loop_body(i, x, y):
x = x.write(i, y + x.gather(indices=[i-1])))
return i, x
i
,x
和y
初始化为:
i = tf.constant(1, dtype=tf.int32)
x = tf.TensorArray(dtype=tf.float32, size=10)
x = x.write(0, [0, 0, 0])
y = tf.constant([1, 2, 3], dtype=tf.float32)
现在,当我运行代码并执行while循环时,出现以下错误:
ValueError: Inconsistent shapes: saw (?, 3) but expected (3,) (and infer_shape=True)
为什么x.gather()
的形状不是(3,)?我应该怎么做?
答案 0 :(得分:2)
文档已经描述了tf.TensorArray.gather()
将以压缩Tensor的形式返回TensorArray 中的选定值。
返回:
通过索引选择的TensorArray中的,打包成一个张量。
因此您将获得(?,3)的形状。您可以更改它:
x = x.write(i, y + x.gather(indices=[i-1])[0])
# or
x = x.write(i, y + x.read(i-1))
此外,您的代码中还有一些错误。我将在下面给出一个完整的示例。
import tensorflow as tf
def condition(i, x,y):
return tf.less(i, 10)
def loop_body(i, x,y):
x = x.write(i, y + x.gather(indices=[i - 1])[0])
#or
# x = x.write(i, y + x.read(i-1))
return i+1, x, y
i = tf.constant(1)
x = tf.TensorArray(dtype=tf.float32,size=1, dynamic_size=True,clear_after_read=False)
x = x.write(0, [0., 0., 0.])
y = tf.constant([1, 2, 3], dtype=tf.float32)
i, x, y = tf.while_loop(condition, loop_body, loop_vars=[i,x,y])
x = x.stack()
with tf.Session():
print(i.eval())
print(x.eval())
#print
10
[[ 0. 0. 0.]
[ 1. 2. 3.]
[ 2. 4. 6.]
[ 3. 6. 9.]
[ 4. 8. 12.]
[ 5. 10. 15.]
[ 6. 12. 18.]
[ 7. 14. 21.]
[ 8. 16. 24.]
[ 9. 18. 27.]]