TensorFlow AttentionCellWrapper implementation seems incorrect

Date: 2017-06-16 02:33:26

Tags: python tensorflow deep-learning attention-model

I am reading the AttentionCellWrapper implementation in TensorFlow and I am confused. An attention decoder is supposed to attend over the hidden states of the source sequence. In the TensorFlow implementation below, however, the attention states are just a sliding window that is sliced and has the cell output appended at every step, so the attention is computed over the cell's own previous outputs rather than the source sequence's hidden states. I am not sure whether my reading is correct. Can anyone clarify?

def call(self, inputs, state):
  """Long short-term memory cell with attention (LSTMA)."""
  if self._state_is_tuple:
    state, attns, attn_states = state
  else:
    states = state
    state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
    attns = array_ops.slice(
      states, [0, self._cell.state_size], [-1, self._attn_size])
    attn_states = array_ops.slice(
      states, [0, self._cell.state_size + self._attn_size],
      [-1, self._attn_size * self._attn_length])
  attn_states = array_ops.reshape(attn_states,
                                [-1, self._attn_length, self._attn_size])
  input_size = self._input_size
  if input_size is None:
    input_size = inputs.get_shape().as_list()[1]
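  # Mix the previous attention vector into the cell input.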
  inputs = _linear([inputs, attns], input_size, True)
  lstm_output, new_state = self._cell(inputs, state)
  if self._state_is_tuple:
    new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
  else:
    new_state_cat = new_state
  new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
  with vs.variable_scope("attn_output_projection"):
    output = _linear([lstm_output, new_attns], self._attn_size, True)
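  # Append the newly projected output to the attention window used at the next step.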
  new_attn_states = array_ops.concat(
    [new_attn_states, array_ops.expand_dims(output, 1)], 1)
  new_attn_states = array_ops.reshape(
    new_attn_states, [-1, self._attn_length * self._attn_size])
  new_state = (new_state, new_attns, new_attn_states)
  if not self._state_is_tuple:
    new_state = array_ops.concat(list(new_state), 1)
  return output, new_state

def _attention(self, query, attn_states):
  conv2d = nn_ops.conv2d
  reduce_sum = math_ops.reduce_sum
  softmax = nn_ops.softmax
  tanh = math_ops.tanh

  with vs.variable_scope("attention"):
    k = vs.get_variable(
      "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
    v = vs.get_variable("attn_v", [self._attn_vec_size])
    hidden = array_ops.reshape(attn_states,
                             [-1, self._attn_length, 1, self._attn_size])
    hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
    y = _linear(query, self._attn_vec_size, True)
    y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size])
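    # Additive attention: score each of the attn_length states against the query, then softmax.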
    s = reduce_sum(v * tanh(hidden_features + y), [2, 3])
    a = softmax(s)
    d = reduce_sum(
      array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * hidden, [1, 2])
    new_attns = array_ops.reshape(d, [-1, self._attn_size])
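    # Drop the oldest entry from the window; call() appends the new output afterwards.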
    new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
    return new_attns, new_attn_states
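
For reference, the wrapper is typically applied by wrapping an ordinary RNN cell and passing a window length. A minimal usage sketch (TF 1.x; the sizes and attn_length below are arbitrary placeholders):

import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, None, 32])  # [batch, time, features]
cell = tf.contrib.rnn.BasicLSTMCell(128)
attn_cell = tf.contrib.rnn.AttentionCellWrapper(cell, attn_length=10, state_is_tuple=True)
outputs, final_state = tf.nn.dynamic_rnn(attn_cell, inputs, dtype=tf.float32)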

1 Answer:

Answer 0 (score: 0)

[Figure 1: simple architecture of the cell with AttentionCellWrapper]

I am fairly sure tf.contrib.rnn.AttentionCellWrapper works as shown in Figure 1, and you are correct.

At the start of the sequence, the hidden states (M in Figure 1) are initialized to zero. At every time step, the first column is sliced off the hidden states and the new output is appended, so at step t the hidden states hold the outputs from step 1 up to step t-1 (at most attn_length of them).
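
In other words, the wrapper does not attend over an encoder's hidden states at all; it attends over a fixed-length window of its own past outputs. A rough NumPy sketch of just that window bookkeeping (the names below are mine, not TensorFlow's):

import numpy as np

attn_length, attn_size = 4, 3

# M in Figure 1: the last attn_length outputs, initialized to zeros before the sequence starts.
attn_states = np.zeros((attn_length, attn_size))

def step(new_output, attn_states):
    attn_states = attn_states[1:]  # drop the oldest entry (as in _attention)
    return np.concatenate([attn_states, new_output[None, :]])  # append the new output (as in call)

for t in range(3):
    output_t = np.full(attn_size, float(t + 1))  # stand-in for the projected cell output at step t
    attn_states = step(output_t, attn_states)

print(attn_states)  # rows hold the most recent outputs, with zeros for steps before the start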