Getting the context vectors from AttentionWrapper

Date: 2019-02-11 09:59:16

Tags: tensorflow seq2seq attention-model

I need to extract the context vectors from the attention mechanism applied to a Seq2Seq model.

My first guess was that I could find them in the output of dynamic_decode:

decoder_cell = tf.nn.rnn_cell.MultiRNNCell(
    [get_a_cell(self.num_units, self.dropout) for _ in range(self.num_layers)])

attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    self.num_units_attention, encoder_output,
    memory_sequence_length=x_lengths)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell, attention_mechanism,
    attention_layer_size=self.num_units_attention,
    alignment_history=True)

decoder_initial_state = decoder_cell.zero_state(batch_size, tf.float32).clone(
    cell_state=encoder_state)

decoder = tf.contrib.seq2seq.BasicDecoder(
    decoder_cell, helper, decoder_initial_state,
    output_layer=self.projection_layer)

outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
    decoder, maximum_iterations=maximum_iterations)

# BasicDecoderOutput exposes rnn_output (there is no rnn_outputs field)
context_vectors = outputs.rnn_output

But then I realized that these are the computed attention vectors, not the context vectors.
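Since alignment_history=True is already set, one workaround I am considering (I am not sure it matches exactly what AttentionWrapper computes internally) is to rebuild the context vectors from the stored alignments, because for Luong attention the context vector is just the alignment-weighted sum of the encoder memory. A rough sketch, assuming encoder_output has shape [batch, max_encoder_time, num_units] and that the final decoder state is kept instead of discarded:

outputs, final_state, _ = tf.contrib.seq2seq.dynamic_decode(
    decoder, maximum_iterations=maximum_iterations)

# alignment_history is a TensorArray; stacking it gives
# [max_decoder_time, batch, max_encoder_time]
alignments = tf.transpose(final_state.alignment_history.stack(), [1, 0, 2])

# context_t = sum_j alignments[t, j] * encoder_output[j]
# -> [batch, max_decoder_time, num_units]
context_vectors = tf.matmul(alignments, encoder_output)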

Where does AttentionWrapper store the context vectors?

0 Answers:

No answers