我已经建立了以下TensorArray
:
ta = tf.TensorArray(
dtype=tf.float32,
size=0,
dynamic_size=True,
element_shape=tf.TensorShape([None, None])
)
,并在ta = ta.write(idx, my_tensor)
内称为while_loop
。
在会话中评估output = ta.stack()
张量时,我收到以下错误消息:
ValueError:不能将'... / TensorArrayWrite / TensorArrayWriteV3'用作 输入到'... / TensorArrayStack_1 / TensorArraySizeV3',因为 '... / TensorArrayWrite / TensorArrayWriteV3'在while循环中。查看资讯 登录以获取更多详细信息。
我不明白此错误消息,请您帮帮我吗?
更新:可能很难想出一个最小的例子,但这就是我正在做的事情:我正在使用对AttentionWrapper的ta
中的cell_input_fn
TensorArray的引用。此回调用于AttentionWrapper
的{{1}}方法中,其中正在写入另一个名为call
的{{1}}。因此TensorArray
代码不是我设计的,它是TF动态RNN计算alignment_history
的一部分。
答案 0 :(得分:1)
不确定这是否在困扰您,但是您必须确保while_loop函数将张量数组作为输入并发出更新的数组作为输出;并且您必须在while_loop的末尾使用TensorArray的 final 版本:
def fn(ta_old):
return ta_old.write(...)
ta_final = while_loop(..., body=fn, [tf.TensorArray(...)])
values = ta_final.stack()
特别是,您永远不要在fn()之外访问ta_old。