Tensorflow中RNN的截断反向传播(BPTT)

时间:2017-10-08 02:37:19

标签: tensorflow lstm rnn

https://www.tensorflow.org/tutorials/recurrent#truncated_backpropagation

这里,官方TF文件说,

  

“为了使学习过程易于处理,通常的做法是创建一个'展开'版本的网络,其中包含固定数量(num_steps)的LSTM输入和输出。”

并且该文件需要;

words = tf.placeholder(tf.int32, [batch_size, num_steps])
lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
initial_state = state = tf.zeros([batch_size, lstm.state_size])
for i in range(num_steps):
    output, state = lstm(words[:, i], state)
    # The rest of the code.
    # ...
final_state = state

# After some code lines...
numpy_state = initial_state.eval()
total_loss = 0.0
for current_batch_of_words in words_in_dataset:
    numpy_state, current_loss = session.run([final_state, loss],
        # Initialize the LSTM state from the previous iteration.
        feed_dict={initial_state: numpy_state, words: current_batch_of_words})
    total_loss += current_loss

这些行实现了截断反向传播(BPTT)部分,但我不确定上面的代码部分本质上是必需的。 Tensorflow(我正在使用1.3)是否会自动进行适当的反向传播,即使手写的后支撑实现部分不存在?放置BPTT实施代码是否会显着提高预测准确度?

上面的代码使用从先前时间步长的RNN层返回的状态来馈送下一个时间步的RNNCell。根据官方文档,RNN(GRUCell,LSTMCell ...)层返回输出和状态的元组,但我只用输出构建了我的模型,并没有触摸状态。我只是将输出传递给完全连接层,并重新整形,然后用tf.losses.softmax_cross_entropy计算损失。

0 个答案:

没有答案