tf.nn.dynamic_rnn的时间反向传播,用于顺序输入(来自批处理)

时间:2018-03-14 08:52:00

标签: tensorflow deep-learning lstm recurrent-neural-network rnn

我有这样的代码:

lstm_cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple = True)
c_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.c], "c_in")
h_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.h], "h_in")
rnn_state_in = (c_in, h_in)
rnn_in = tf.expand_dims(previous_layer, [0])
sequence_length = #size of my batch
rnn_state_in = tf.contrib.rnn.LSTMStateTuple(c_in, h_in)
lstm_outputs, lstm_state = tf.nn.dynamic_rnn(lstm_cell,
                                                    rnn_in,
                                                    initial_state = rnn_state_in,
                                                    sequence_length = sequence_length,
                                                    time_major = False)
lstm_c, lstm_h = lstm_state
rnn_out = tf.reshape(lstm_outputs, [-1, 256])

在这里,我使用dynamic_rnn来模拟批处理的时间步长。 每次前进,我都可以获得lstm_c, lstm_h,我可以在外面的任何地方存储。

因此,假设我已经在网络中的序列中对N个项目进行了正向传递,并且具有dynamic_rnn提供的最终单元状态和隐藏状态。现在,为了执行反向传播,我应该对LSTM进行什么输入?

默认情况下,会在dynamic_rnn的时间步骤中发生反向支持吗?

(比如,时间步数= batch_size = N)

因此,我可以将输入提供为:

sess.run(_train_op, feed_dict = {_state: np.vstack(batch_states),
                                            ...
                                            c_in: prev_rnn_state[0],
                                            h_in: prev_rnn_state[1]
})

(其中prev_rnn_statecell state, hidden state的元组,我从前一批的前向传播中获取dynamic_rnn。)

或者我是否明确地在时间序列中展开LSTM层并通过提供一个单元格状态的向量并在前一个时间序列中收集隐藏来训练它?

1 个答案:

答案 0 :(得分:1)

是的,backprop在dynamic_rnn的时间步骤中发生。

但是,我认为您正在研究inputs的{​​{1}}参数。它应该是dynamic_rnn形状。当您使用类似该形状的输入调用dynamic_rnn时,它会使用您[batch_size, max_time, ...]提供的初始状态调用lstm_cell max_time次。

请记住,每次步骤rnn_state_in都会自动从上一个时间步骤中获取c和h状态。所以你不必每次都在sess.run(..)里面喂它们。您只需输入输入。

当你使用lstm的最终状态(或所有状态)计算损失并使用像SGD或adam这样的优化器时,将计算所有时间步的backprop。