当state_is_tuple = True时,如何设置TensorFlow RNN状态?

时间:2016-08-24 00:22:34

标签: python machine-learning tensorflow

我写了RNN language model using TensorFlow。该模型实现为RNN类。图结构是在构造函数中构建的,而RNN.trainRNN.test方法则运行它。

我希望能够在移动到训练集中的新文档时重置RNN状态,或者当我想在训练期间运行验证集时。我这样做是通过管理训练循环内的状态,通过提要字典将其传递到图表中。

在构造函数中,我像这样定义RNN

    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
                                                  initial_state=self.state)

训练循环如下所示

 for document in document:
     state = session.run(self.reset_state)
     for x, y in document:
          _, state = session.run([self.train_step, self.next_state], 
                                 feed_dict={self.x:x, self.y:y, self.state:state})

xy是文档中的批量培训数据。这个想法是我在每个批次之后传递最新状态,除非我开始一个新文档,当我通过运行self.reset_state将状态归零时。

这一切都有效。现在我想更改我的RNN以使用推荐的state_is_tuple=True。但是,我不知道如何通过提要字典传递更复杂的LSTM状态对象。另外,我不知道在构造函数中传递给self.state = tf.placeholder(...)行的参数是什么。

这里的正确策略是什么? dynamic_rnn可用的代码或文档仍然不多。

TensorFlow问题26952838似乎相关。

WILDML上的blog post解决了这些问题,但没有直接说出答案。

另见TensorFlow: Remember LSTM state for next batch (stateful LSTM)

2 个答案:

答案 0 :(得分:21)

Tensorflow占位符的一个问题是你只能用Python列表或Numpy数组(我认为)来提供它。因此,您无法在LSTMStateTuple的元组中的运行之间保存状态。

我通过将状态保存在这样的张量中来解决这个问题

initial_state = np.zeros((num_layers, 2, batch_size, state_size))

LSTM层中有两个组件,单元状态隐藏状态,这就是" 2"来自。 (这篇文章很棒:https://arxiv.org/pdf/1506.00019.pdf

构建图形时,解压缩并创建元组状态,如下所示:

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
l = tf.unpack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
         [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1])
          for idx in range(num_layers)]
)

然后你以通常的方式获得新状态

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)

它不应该是这样的......也许他们正在研究解决方案。

答案 1 :(得分:2)

以RNN状态进行馈送的一种简单方法是单独输入状态元组的两个组件。

# Constructing the graph
self.state = rnn_cell.zero_state(...)
self.output, self.next_state = tf.nn.dynamic_rnn(
    rnn_cell,
    self.input,
    initial_state=self.state)

# Running with initial state
output, state = sess.run([self.output, self.next_state], feed_dict={
    self.input: input
})

# Running with subsequent state:
output, state = sess.run([self.output, self.next_state], feed_dict={
    self.input: input,
    self.state[0]: state[0],
    self.state[1]: state[1]
})