批次之间传递LSTM状态的最佳方法

时间:2018-03-03 08:07:34

标签: python tensorflow lstm

我试图找到在批次之间传递LSTM状态的最佳方法。我搜索了所有内容,但我找不到当前实现的解决方案。想象一下,我有类似的东西:

cells = [rnn.LSTMCell(size) for size in [256,256]
cells = rnn.MultiRNNCell(cells, state_is_tuple=True)
init_state = cells.zero_state(tf.shape(x_hot)[0], dtype=tf.float32)
net, new_state = tf.nn.dynamic_rnn(cells, x_hot, initial_state=init_state ,dtype=tf.float32)

现在我想有效地传递每个批次中的new_state,因此不将其存储回内存,然后使用feed_dict重新输入到tf。更准确地说,我发现的所有解决方案都使用sess.run来评估new_statefeed-dict以将其传递给init_state。有没有办法在没有使用feed-dict的瓶颈的情况下这样做?

我认为我应该以某种方式使用tf.assign,但文档不完整,我找不到任何解决方法。

我要感谢所有提前提出要求的人。

干杯,

Francesco Saverio

我在堆栈溢出中找到的所有其他答案适用于旧版本或使用' feed-dict'传递新状态的方法。例如:

1)TensorFlow: Remember LSTM state for next batch (stateful LSTM)这可以通过使用' feed-dict'提供州占位符,我想避免那个

2)Tensorflow - LSTM state reuse within batch这不适用于状态turple

3)Saving LSTM RNN state between runs in Tensorflow同样在这里

1 个答案:

答案 0 :(得分:3)

LSTMStateTuple只不过是输出和隐藏状态的元组。 tf.assign创建一个操作,在运行时,将存储在张量中的值分配给变量(如果您有特定问题,请询问以便可以改进文档)。您可以使用tf.assign的解决方案,使用元组的c属性从元组中检索隐藏状态张量(假设您想要隐藏状态) - new_state.c

以下是有关玩具问题的完整自包含示例:https://gist.github.com/iganichev/632b425fed0263d0274ec5b922aa3b2f