How to visualize attention weights from AttentionWrapper

Asked: 2017-06-18 08:35:31

Tags: tensorflow visualization

I want to visualize the attention scores in the latest version of TensorFlow (1.2). I use AttentionWrapper from contrib.seq2seq to build an RNNCell, use BasicDecoder as the decoder, and then use dynamic_decode() to generate the outputs step by step.

How can I get the attention weights for all the steps? Thanks!

1 Answer:

Answer 0 (score: 4):

You can access the attention weights by setting the alignment_history=True flag when constructing the AttentionWrapper.

Here is an example:

import tensorflow as tf

# Define attention mechanism over the encoder outputs (the memory the
# decoder attends to)
attn_mech = tf.contrib.seq2seq.LuongMonotonicAttention(
    num_units = attention_unit_size, memory = encoder_outputs,
    memory_sequence_length = input_lengths)

# Define attention cell
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell = decoder_cell, attention_mechanism = attn_mech,
    alignment_history=True)

# Define train helper (fed with the ground-truth decoder input sequence)
train_helper = tf.contrib.seq2seq.TrainingHelper(
    inputs = decoder_inputs,
    sequence_length = output_lengths)

# Define decoder
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell = attn_cell, 
    helper = train_helper, initial_state=decoder_initial_state)

# Dynamic decoding
dec_outputs, dec_states, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
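
Note that when the decoder cell is an AttentionWrapper, the decoder's initial state has to be of the wrapper's own state type. Below is a minimal sketch of one common way to build decoder_initial_state, assuming batch_size and encoder_final_state come from your encoder setup (both names are placeholders here):

# batch_size and encoder_final_state are assumed to exist in your graph;
# the wrapper's zero state is cloned so it carries the encoder's final state.
decoder_initial_state = attn_cell.zero_state(
    batch_size = batch_size, dtype = tf.float32).clone(
    cell_state = encoder_final_state)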

Then, in a session, you can access the weights: dec_states is an AttentionWrapperState whose alignment_history field is a TensorArray, which can be stacked into a single tensor:

with tf.Session() as sess:
    ...
    alignments = sess.run(dec_states.alignment_history.stack(), feed_dict)
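
The stacked history is a single tensor of shape [output_time, batch_size, input_time], with one alignment vector per decoder step. A quick, purely illustrative sanity check:

# alignments[t, i, :] is the attention distribution over encoder positions
# at decoder step t for batch element i (alignments is a numpy array here).
print(alignments.shape)  # e.g. (max_output_time, batch_size, max_input_time)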

Finally, you can visualize the attention (the alignments) like this:

import matplotlib.pyplot as plt

def plot_attention(attention_map, input_tags=None, output_tags=None):
    # attention_map has shape [output_length, input_length]
    output_len, input_len = attention_map.shape

    # Plot the attention_map
    plt.clf()
    f = plt.figure(figsize=(15, 10))
    ax = f.add_subplot(1, 1, 1)

    # Add image
    img = ax.imshow(attention_map, interpolation='nearest', cmap='Blues')

    # Add colorbar
    cbaxes = f.add_axes([0.2, 0, 0.6, 0.03])
    cbar = f.colorbar(img, cax=cbaxes, orientation='horizontal')
    cbar.ax.set_xlabel('Alpha value (probability output of the "softmax")', labelpad=2)

    # Add labels
    ax.set_yticks(range(output_len))
    if output_tags is not None:
        ax.set_yticklabels(output_tags[:output_len])

    ax.set_xticks(range(input_len))
    if input_tags is not None:
        ax.set_xticklabels(input_tags[:input_len], rotation=45)

    ax.set_xlabel('Input Sequence')
    ax.set_ylabel('Output Sequence')

    # Add grid
    ax.grid()

    plt.show()

# input_tags - word representation of input sequence, use None to skip
# output_tags - word representation of output sequence, use None to skip
# i - index of input element in batch

plot_attention(alignments[:, i, :], input_tags, output_tags)
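
In practice you may also want to crop away the padded time steps before plotting. A small sketch, assuming input_lengths and output_lengths are available on the Python side as arrays of the true sequence lengths (hypothetical names matching the graph above):

# Crop the map of batch element i to its true lengths before plotting
in_len = input_lengths[i]
out_len = output_lengths[i]
plot_attention(alignments[:out_len, i, :in_len],
               input_tags[:in_len] if input_tags else None,
               output_tags[:out_len] if output_tags else None)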
