如何在while_loop中使用TensorArray.stack

时间:2018-09-13 17:14:56

标签: python-3.x tensorflow

我的问题是我需要在while_loop的每次迭代中使用TensorArray的累积元素,然后向TensorArray中添加另一个依赖于先前元素的元素。

我得到的错误是:

InvalidArgumentError (see above for traceback): TensorArray TensorArray_0: Could not write to TensorArray index 1 because it has already been read.

MVCE

import tensorflow as tf
n_iter = 4


def _body(j, ta):
    X = ta.stack()
    out = tf.ones((2,2),dtype=tf.float32) * tf.reduce_sum(X)

    return (j+1, ta.write(j+1,out))

loop_vars = [
    tf.constant(0), # iter
    tf.TensorArray(tf.float32, n_iter, dynamic_size=True, infer_shape=False,element_shape=(2,2),clear_after_read=False)
            ]


_, ta = tf.while_loop(lambda j, _: j < n_iter, _body, loop_vars=loop_vars)

tf.Session().run(ta.stack())

系统

TF 1.10

0 个答案:

没有答案