Tensorflow:如何使用dynamic_rnn从LSTMCell获取中间单元状态(c)?

时间:2017-12-11 00:37:53

标签: python machine-learning tensorflow lstm rnn

默认情况下,函数dynamic_rnn仅为每个时间点输出隐藏状态(称为m),可以按如下方式获取:

cell = tf.contrib.rnn.LSTMCell(100)
rnn_outputs, _ = tf.nn.dynamic_rnn(cell,
                                   inputs=inputs,
                                   sequence_length=sequence_lengths,
                                   dtype=tf.float32)

还有办法获得中间(非最终)单元状态(c)吗?

tensorflow撰稿人mentions可以使用单元格包装器完成:

class Wrapper(tf.nn.rnn_cell.RNNCell):
  def __init__(self, inner_cell):
     super(Wrapper, self).__init__()
     self._inner_cell = inner_cell
  @property
  def state_size(self):
     return self._inner_cell.state_size
  @property
  def output_size(self):
    return (self._inner_cell.state_size, self._inner_cell.output_size)
  def call(self, input, state)
    output, next_state = self._inner_cell(input, state)
    emit_output = (next_state, output)
    return emit_output, next_state

然而,它似乎不起作用。有什么想法吗?

2 个答案:

答案 0 :(得分:2)

建议的解决方案适合我,但Layer.call方法规范更为通用,因此以下Wrapper应该对API更改更加健壮。你的意思是:

class Wrapper(tf.nn.rnn_cell.RNNCell):
  def __init__(self, inner_cell):
     super(Wrapper, self).__init__()
     self._inner_cell = inner_cell

  @property
  def state_size(self):
     return self._inner_cell.state_size

  @property
  def output_size(self):
    return (self._inner_cell.state_size, self._inner_cell.output_size)

  def call(self, input, *args, **kwargs):
    output, next_state = self._inner_cell(input, *args, **kwargs)
    emit_output = (next_state, output)
    return emit_output, next_state

以下是测试:

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
basic_cell = Wrapper(tf.nn.rnn_cell.LSTMCell(num_units=n_neurons, state_is_tuple=False))
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
print(outputs, states)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val = outputs[0].eval(feed_dict={X: X_batch})
  print(outputs_val)

返回outputs(?, 2, 10)(?, 2, 5)张量的元组,它们都是LSTM状态和输出。请注意,我使用"毕业"版本LSTMCell,来自tf.nn.rnn_cell包,而不是tf.contrib.rnn。另请注意state_is_tuple=True以避免处理LSTMStateTuple

答案 1 :(得分:0)

根据Maxim的想法,我最终得到了以下解决方案:

class StatefulLSTMCell(LSTMCell):
    def __init__(self, *args, **kwargs):
        super(StatefulLSTMCell, self).__init__(*args, **kwargs)

    @property
    def output_size(self):
        return (self.state_size, super(StatefulLSTMCell, self).output_size)

    def call(self, input, state):
        output, next_state = super(StatefulLSTMCell, self).call(input, state)
        emit_output = (next_state, output)
        return emit_output, next_state