什么是"来源隐藏状态"在注意机制中提到?

时间:2018-01-23 04:01:47

标签: machine-learning nlp deep-learning sequence-to-sequence attention-model

注意力量计算如下:

enter image description here

我想知道h_s指的是什么。

在张量流代码中,编码器RNN返回元组:

encoder_outputs, encoder_state = tf.nn.dynamic_rnn(...)

我认为h_s应该是encoder_state,但github/nmt会给出不同的答案?

# attention_states: [batch_size, max_time, num_units]
attention_states = tf.transpose(encoder_outputs, [1, 0, 2])

# Create an attention mechanism
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units, attention_states,
    memory_sequence_length=source_sequence_length)

我误解了代码吗?或者h_s实际上是指encoder_outputs

1 个答案:

答案 0 :(得分:2)

该公式可能来自this post,因此我将使用同一帖子中的NN图片:

nn

此处,h-bar(s)是来自编码器(最后一层)的所有蓝色隐藏状态,而h(t)是来自的当前红色隐藏状态解码器(也是最后一层)。一张图片t=0,您可以看到哪些块用虚线箭头连接到注意重量。 score函数通常是其中之一:

formula

Tensorflow关注机制与此图片相匹配。理论上,在大多数情况下,单元输出 是隐藏状态(一个例外是LSTM单元,其中输出是状态的短期部分,即使在这种情况下输出也适合于注意机制)。实际上,当输入用零填充时,tensorflow的encoder_stateencoder_outputs不同:状态从前一个单元状态传播而输出为零。显然,你不想关注尾随零,因此对这些单元格有h-bar(s)是有意义的。

所以encoder_outputs正好是蓝色方块向上的箭头。稍后在代码中,attention_mechanism连接到每个decoder_cell,以便其输出通过上下文向量到达图片上的黄色块。

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell, attention_mechanism,
    attention_layer_size=num_units)