在TensorFlow中初始化LSTM隐藏状态

时间:2018-12-04 09:40:23

标签: tensorflow neural-network deep-learning lstm rnn

有人可以告诉我,在TensorFlow框架中,如何使用用户定义的值初始化LSTM网络的隐藏状态吗?我试图通过提供第一个LSTM单元的特定隐藏状态,将辅助信息合并到LSTM中。

1 个答案:

答案 0 :(得分:0)

您可以通过负责展开图形的函数的参数initial_state传递LSTM的初始隐藏状态。

我假设您将在tensorflow中使用以下某些功能来创建递归神经网络(RNN):tf.nn.dynamic_rnnbidirectional_dynamic_rnntf.nn.static_rnn或{{3} }。 它们都有一个initial_state参数。对于双向RNN,您需要同时传递前向(initial_state_fw)和后向(initial_state_bw)传递的初始状态。

使用tf.nn.dynamic_rnn定义模型的示例:

import tensorflow as tf

batch_size = 32
max_sequence_length = 100
num_features = 128
num_units = 64 

input_sequence = tf.placeholder(tf.float32, shape=[batch_size, max_sequence_length, num_features])
input_sequence_lengths = tf.placeholder(tf.int32, shape=[batch_size])

cell = tf.nn.rnn_cell.LSTMCell(num_units=num_units, state_is_tuple=True)

# Initial states
cell_state = tf.zeros([batch_size, num_units], tf.float32)
hidden_state = tf.placeholder(tf.float32, [batch_size, num_units])
my_initial_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)

outputs, states = tf.nn.dynamic_rnn(
                    cell=cell,
                    inputs=input_sequence,
                    initial_state=my_initial_state,
                    sequence_length=input_sequence_lengths)

由于我们使用state_is_tuple=True,因此我们需要传递一个初始状态,该初始状态是cell_statehidden_state的元组。 在tf.nn.static_bidirectional_rnn 的文档中,该元组对应于c_statem_stateLSTMCell指出这分别代表单元状态和隐藏状态。

因此,由于我们只想初始化第一个隐藏状态,所以cell_state被初始化为零。