I want to visualize the attention scores in the latest version of TensorFlow (1.2). I am using AttentionWrapper from contrib.seq2seq to build an RNNCell, BasicDecoder as the decoder, and then dynamic_decode() to generate the output step by step.
How can I get the attention weights for all steps? Thanks!
Answer (score: 4)
You can access the attention weights (alignments) by setting alignment_history=True when constructing the AttentionWrapper.
Here is an example:
# Define the attention mechanism over the encoder outputs
attn_mech = tf.contrib.seq2seq.LuongMonotonicAttention(
    num_units=attention_unit_size,
    memory=encoder_outputs,  # the memory attended over is the encoder output
    memory_sequence_length=input_lengths)

# Wrap the decoder cell with attention; alignment_history=True records the
# attention weights of every decoding step in a TensorArray
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell=decoder_cell,
    attention_mechanism=attn_mech,
    alignment_history=True)

# Define the training helper, which feeds the (embedded) target sequence
train_helper = tf.contrib.seq2seq.TrainingHelper(
    inputs=decoder_inputs,
    sequence_length=output_lengths)

# Define the decoder
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=attn_cell,
    helper=train_helper,
    initial_state=decoder_initial_state)

# Dynamic decoding (in TF 1.2 this also returns the final sequence lengths)
dec_outputs, dec_states, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
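The snippet leaves a few names undefined (the encoder, the plain decoder cell, the initial state). Here is a minimal sketch of how they might be wired up, assuming already-embedded inputs and LSTM cells; every name in it is hypothetical:

import tensorflow as tf

# Hypothetical encoder: embedded_encoder_inputs has shape [batch, time, embed_dim]
encoder_cell = tf.contrib.rnn.LSTMCell(attention_unit_size)
encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
    encoder_cell, embedded_encoder_inputs,
    sequence_length=input_lengths, dtype=tf.float32)

# Plain decoder cell, to be wrapped by the AttentionWrapper above
decoder_cell = tf.contrib.rnn.LSTMCell(attention_unit_size)

# Once attn_cell exists, its initial state is the wrapper's zero_state
# cloned with the encoder's final state
decoder_initial_state = attn_cell.zero_state(
    batch_size, dtype=tf.float32).clone(cell_state=encoder_final_state)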
Then, inside a session, you can fetch the weights as follows:
with tf.Session() as sess:
    ...
    # Stack the TensorArray into a single tensor before fetching it
    alignments = sess.run(dec_states.alignment_history.stack(), feed_dict)
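alignment_history is a TensorArray, so stack() yields one tensor of shape [max_decoder_steps, batch_size, max_encoder_steps]: decoder steps on the first axis, weights over encoder positions on the last. A quick sanity check:

# alignments is a NumPy array after sess.run
# Row t of alignments[:, i, :] is the attention over encoder positions
# at decoding step t for batch element i
print(alignments.shape)  # (max_decoder_steps, batch_size, max_encoder_steps)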
Finally, you can visualize the attention (alignments) like this:
import matplotlib.pyplot as plt

def plot_attention(attention_map, input_tags=None, output_tags=None):
    # attention_map: [output_len, input_len] matrix of alignment weights
    out_len, in_len = attention_map.shape

    # Plot the attention map
    plt.clf()
    f = plt.figure(figsize=(15, 10))
    ax = f.add_subplot(1, 1, 1)

    # Add the image
    i = ax.imshow(attention_map, interpolation='nearest', cmap='Blues')

    # Add a colorbar
    cbaxes = f.add_axes([0.2, 0, 0.6, 0.03])
    cbar = f.colorbar(i, cax=cbaxes, orientation='horizontal')
    cbar.ax.set_xlabel('Alpha value (probability output of the "softmax")',
                       labelpad=2)

    # Add labels
    ax.set_yticks(range(out_len))
    if output_tags is not None:
        ax.set_yticklabels(output_tags[:out_len])
    ax.set_xticks(range(in_len))
    if input_tags is not None:
        ax.set_xticklabels(input_tags[:in_len], rotation=45)
    ax.set_xlabel('Input Sequence')
    ax.set_ylabel('Output Sequence')

    # Add a grid
    ax.grid()
    plt.show()
# input_tags  - word representations of the input sequence; pass None to skip
# output_tags - word representations of the output sequence; pass None to skip
# i           - index of the example in the batch
plot_attention(alignments[:, i, :], input_tags, output_tags)
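If the batch is padded, it may help to trim the map to the true lengths before plotting. A sketch, assuming you know the token lists for example i (names hypothetical):

# Trim padded steps: rows are decoder steps, columns are encoder positions
i = 0  # hypothetical batch index
att = alignments[:len(output_tags), i, :len(input_tags)]
plot_attention(att, input_tags=input_tags, output_tags=output_tags)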