我正在使用Tensorflow NN,它使用LSTM跟踪参数(时间序列数据回归问题)。一批训练数据包含连续观察的batch_size。我想使用LSTM状态作为下一个样本的输入。所以,如果我有一批数据观察,我想将第一次观测的状态作为第二次观测的输入,依此类推。下面我将lstm状态定义为size = batch_size的张量。我想在批次中重用状态:
state = tf.Variable(cell.zero_states(batch_size, tf.float32), trainable=False)
cell = tf.nn.rnn_cell.BasicLSTMCell(100)
output, curr_state = tf.nn.rnn(cell, data, initial_state=state)
在API中有一个tf.nn.state_saving_rnn,但文档有点模糊。 我的问题:如何在培训批次中重复使用curr_state 。
答案 0 :(得分:1)
您基本上就在那里,只需要使用state
更新curr_state
:
state_update = tf.assign(state, curr_state)
然后,确保在run
本身上调用state_update
或将state_update
作为依赖项的操作调用,否则分配将不会实际发生。例如:
with tf.control_dependencies([state_update]):
model_output = ...
正如评论中所建议的那样,RNN的典型情况是你有一个批处理,其中第一维(0)是序列数,第二维(1)是每个序列的最大长度(如果你通过time_major=True
当您构建RNN时,这两个被交换)。理想情况下,为了获得良好的性能,您可以将多个序列堆叠到一个批处理中,然后按时间分割该批处理。但这确实是一个不同的主题。