我试图找到在批次之间传递LSTM状态的最佳方法。我搜索了所有内容,但我找不到当前实现的解决方案。想象一下,我有类似的东西:
cells = [rnn.LSTMCell(size) for size in [256,256]
cells = rnn.MultiRNNCell(cells, state_is_tuple=True)
init_state = cells.zero_state(tf.shape(x_hot)[0], dtype=tf.float32)
net, new_state = tf.nn.dynamic_rnn(cells, x_hot, initial_state=init_state ,dtype=tf.float32)
现在我想有效地传递每个批次中的new_state
,因此不将其存储回内存,然后使用feed_dict
重新输入到tf。更准确地说,我发现的所有解决方案都使用sess.run
来评估new_state
和feed-dict
以将其传递给init_state
。有没有办法在没有使用feed-dict
的瓶颈的情况下这样做?
我认为我应该以某种方式使用tf.assign
,但文档不完整,我找不到任何解决方法。
我要感谢所有提前提出要求的人。
干杯,
Francesco Saverio
我在堆栈溢出中找到的所有其他答案适用于旧版本或使用' feed-dict'传递新状态的方法。例如:
1)TensorFlow: Remember LSTM state for next batch (stateful LSTM)这可以通过使用' feed-dict'提供州占位符,我想避免那个
2)Tensorflow - LSTM state reuse within batch这不适用于状态turple
答案 0 :(得分:3)
LSTMStateTuple
只不过是输出和隐藏状态的元组。 tf.assign
创建一个操作,在运行时,将存储在张量中的值分配给变量(如果您有特定问题,请询问以便可以改进文档)。您可以使用tf.assign
的解决方案,使用元组的c
属性从元组中检索隐藏状态张量(假设您想要隐藏状态) - new_state.c
以下是有关玩具问题的完整自包含示例:https://gist.github.com/iganichev/632b425fed0263d0274ec5b922aa3b2f