Tensorflow - 批量内的LSTM状态重用

时间:2017-02-09 10:06:30

标签: tensorflow lstm recurrent-neural-network

我正在使用Tensorflow NN,它使用LSTM跟踪参数(时间序列数据回归问题)。一批训练数据包含连续观察的batch_size。我想使用LSTM状态作为下一个样本的输入。所以,如果我有一批数据观察,我想将第一次观测的状态作为第二次观测的输入,依此类推。下面我将lstm状态定义为size = batch_size的张量。我想在批次中重用状态

state = tf.Variable(cell.zero_states(batch_size, tf.float32), trainable=False)
cell = tf.nn.rnn_cell.BasicLSTMCell(100)
output, curr_state = tf.nn.rnn(cell, data, initial_state=state) 

在API中有一个tf.nn.state_saving_rnn,但文档有点模糊。 我的问题:如何在培训批次中重复使用curr_state

1 个答案:

答案 0 :(得分:1)

您基本上就在那里,只需要使用state更新curr_state

state_update = tf.assign(state, curr_state)

然后,确保在run本身上调用state_update或将state_update作为依赖项的操作调用,否则分配将不会实际发生。例如:

with tf.control_dependencies([state_update]):
    model_output = ...

正如评论中所建议的那样,RNN的典型情况是你有一个批处理,其中第一维(0)是序列数,第二维(1)是每个序列的最大长度(如果你通过time_major=True当您构建RNN时,这两个被交换)。理想情况下,为了获得良好的性能,您可以将多个序列堆叠到一个批处理中,然后按时间分割该批处理。但这确实是一个不同的主题。