写入TensorArray错误:“无法读取索引”

时间:2018-12-22 03:35:43

标签: python tensorflow

我已经为玩具问题编写了一些TensorFlow代码,该代码使用了tf.while_loop。循环写入tf.TensorArray。在第一次迭代中,我想将10写入数组的第一个元素。然后在后续迭代中,我想将1 + (i-1)写入ith element中。因此,最终数组应类似于:[10, 11, 12, 13, 14, ...]

这是我的代码,用于循环中只有两次迭代的情况:

将tensorflow导入为tf

def loop_body(i, x):
    x = tf.cond(tf.equal(i, 0), lambda: x.write(i, 10), lambda: x.write(i, 1 + x.read(i-1)))
    i = tf.add(i, 1)
    return i, x

num_iterations = 2
iteration_num = tf.constant(0, dtype=tf.int32)
array = tf.TensorArray(dtype=tf.int32, size=num_iterations)
loop_condition = lambda iteration_num, predictions: tf.less(iteration_num, num_iterations)
_, loop_output = tf.while_loop(loop_condition, loop_body, [iteration_num, array])
loop_op = loop_output.stack()

sess = tf.Session()
sess.run(tf.global_variables_initializer())
outputs = sess.run(loop_op)
print(outputs)

运行此命令时,出现以下错误:

Invalid argument: TensorArray TensorArray_0: Could not read index 0 twice because it was cleared after a previous read (perhaps try setting clear_after_read = false?).

但是我不明白这个错误。在迭代的第二个循环中,我应该只读取一次索引0。在第一个循环中,我没有读取索引0,而只是将10分配给数组的第一个元素。

是什么导致此错误?

1 个答案:

答案 0 :(得分:0)

首先,请解释该错误实际上是由stack()引起的。在迭代的第二个循环中,您只读取一次索引0。但是stack()的角色再次被读取为0。换句话说,您的tf.while_loop已正确执行。让我们看看以下步骤。

# loop_op = loop_output.stack() #commented code

outputs = sess.run(loop_output.read(1)) #Modify code
print(outputs)

#print
11

您可以看到已经成功制作了11个。如果将其更改为read(0),它将报告相同的错误。

然后解决该错误的方法是设置clear_after_read=False

array = tf.TensorArray(dtype=tf.int32, size=num_iterations,clear_after_read=False)