Tensorflow Stacked GRU Cell

时间:2017-08-08 19:49:38

标签: tensorflow rnn gated-recurrent-unit

我正在尝试在张量流中实现具有MultiRNNCell和GRUCell的堆叠RNN。

从GRUCell的默认实现中可以看出"输出"和#34;州" GRUCell的内容是相同的:

class GRUCell(RNNCell)
  ...
  def call(self, inputs, state):
    ...
    new_h = u * state + (1 - u) * c
    return new_h, new_h

这是有道理的,因为它与定义一致。但是,当我们将它们与MultiRNNCell堆叠时,它被定义为:

class MultiRNNCell(RNNCell):
  ...
  def call(self, inputs, state):
    ...
    cur_inp = inputs
    new_states = []
    for i, cell in enumerate(self._cells):
      # set cur_state = states[i] ...
      cur_inp, new_state = cell(cur_inp, cur_state)
      new_states.append(new_state)
    return cur_inp, new_states

(代码已经过浓缩以突出显示相关位)

在这种情况下,任何不是第一个的GRUCell都会收到相同的"输入"和"州"。基本上,它在单个输入上运行,这是前一层的输出。

由于复位/更新门的值取决于两个输入值(输入/状态)的比较,因此最终不会成为冗余操作,这最终会导致值直接传递从第一层?

似乎MultiRNNCell的这种架构主要是考虑到LSTM Cell而设计的,因为它们将输出和单元状态分开,但不适合GRU Cell。

0 个答案:

没有答案