在while循环中读取并分配给TensorArray

时间:2018-12-18 18:14:55

标签: python tensorflow

在TensorFlow中,我有一个tf.while_loop,其中涉及到TensorArray的使用。我写了一个最小的玩具示例来演示我遇到的一个问题。

对于每个循环,我想读取此数组中一个元素的值,将其添加到张量,然后将结果分配给数组的另一个元素。 while循环的body参数定义为以下函数:

def loop_body(i, x, y):
    x = x.write(i, y + x.gather(indices=[i-1])))
    return i, x

ixy初始化为:

i = tf.constant(1, dtype=tf.int32)
x = tf.TensorArray(dtype=tf.float32, size=10)
x = x.write(0, [0, 0, 0])
y = tf.constant([1, 2, 3], dtype=tf.float32)

现在,当我运行代码并执行while循环时,出现以下错误:

ValueError: Inconsistent shapes: saw (?, 3) but expected (3,) (and infer_shape=True)

为什么x.gather()的形状不是(3,)?我应该怎么做?

1 个答案:

答案 0 :(得分:2)

文档已经描述了tf.TensorArray.gather()将以压缩Tensor的形式返回TensorArray 中的选定值

  

返回:

     

通过索引选择的TensorArray中的,打包成一个张量。

因此您将获得(?,3)的形状。您可以更改它:

x = x.write(i, y + x.gather(indices=[i-1])[0])
# or
x = x.write(i, y + x.read(i-1))

此外,您的代码中还有一些错误。我将在下面给出一个完整的示例。

import tensorflow as tf

def condition(i, x,y):
    return tf.less(i, 10)

def loop_body(i, x,y):
    x = x.write(i, y + x.gather(indices=[i - 1])[0])
    #or
    # x = x.write(i, y + x.read(i-1))
    return i+1, x, y

i = tf.constant(1)
x = tf.TensorArray(dtype=tf.float32,size=1, dynamic_size=True,clear_after_read=False)
x = x.write(0, [0., 0., 0.])
y = tf.constant([1, 2, 3], dtype=tf.float32)

i, x, y = tf.while_loop(condition, loop_body, loop_vars=[i,x,y])
x = x.stack()

with tf.Session():
    print(i.eval())
    print(x.eval())

#print
10
[[ 0.  0.  0.]
 [ 1.  2.  3.]
 [ 2.  4.  6.]
 [ 3.  6.  9.]
 [ 4.  8. 12.]
 [ 5. 10. 15.]
 [ 6. 12. 18.]
 [ 7. 14. 21.]
 [ 8. 16. 24.]
 [ 9. 18. 27.]]