在while循环中评估TensorArray时出错

时间:2019-02-11 13:52:50

标签: tensorflow

我已经建立了以下TensorArray

ta = tf.TensorArray(
    dtype=tf.float32,   
    size=0,
    dynamic_size=True,
    element_shape=tf.TensorShape([None, None])
)

,并在ta = ta.write(idx, my_tensor)内称为while_loop

在会话中评估output = ta.stack()张量时,我收到以下错误消息:

  

ValueError:不能将'... / TensorArrayWrite / TensorArrayWriteV3'用作   输入到'... / TensorArrayStack_1 / TensorArraySizeV3',因为   '... / TensorArrayWrite / TensorArrayWriteV3'在while循环中。查看资讯   登录以获取更多详细信息。

我不明白此错误消息,请您帮帮我吗?

更新:可能很难想出一个最小的例子,但这就是我正在做的事情:我正在使用对AttentionWrapperta中的cell_input_fn TensorArray的引用。此回调用于AttentionWrapper的{​​{1}}方法中,其中正在写入另一个名为call的{​​{1}}。因此TensorArray代码不是我设计的,它是TF动态RNN计算alignment_history的一部分。

1 个答案:

答案 0 :(得分:1)

不确定这是否在困扰您,但是您必须确保while_loop函数将张量数组作为输入并发出更新的数组作为输出;并且您必须在while_loop的末尾使用TensorArray的 final 版本:

def fn(ta_old):
  return ta_old.write(...)

ta_final = while_loop(..., body=fn, [tf.TensorArray(...)])

values = ta_final.stack()

特别是,您永远不要在fn()之外访问ta_old。