How to use the tf AttentionWrapper API

Time: 2018-08-01 12:40:29

Tags: tensorflow rnn attention-model

I want to implement an RNN model with simple self-attention: the output at step i should be a linear combination of hidden state i and an attention vector over the previous hidden states (the memory). Does the code below make sense?

import tensorflow as tf

# Base LSTM cell
rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units)

# Run the cell once to collect the hidden states that will serve as the memory
h_states, final_state = tf.nn.dynamic_rnn(
    rnn_cell, input_tensor, dtype=tf.float32, time_major=False)

# Luong-style attention over those hidden states
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units, h_states,
    memory_sequence_length=source_sequence_length)

# Wrap the cell so each step's output is combined with the attention context
rnn_cell_with_attention = tf.contrib.seq2seq.AttentionWrapper(
    rnn_cell, attention_mechanism,
    attention_layer_size=num_units)
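
In case it helps, this is roughly how I intend to run the wrapped cell afterwards. It is only a minimal sketch under my assumptions: I use a fresh LSTM cell for the attention pass (to avoid reusing rnn_cell's variables in a second dynamic_rnn call) and a batch_size that is not defined in the snippet above.

# Minimal sketch of actually running the attention-wrapped cell.
# Assumes a separate cell and a known batch_size (not part of the code above).
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.contrib.rnn.BasicLSTMCell(num_units),  # fresh cell for the attention pass
    attention_mechanism,
    attention_layer_size=num_units)

# AttentionWrapper is itself an RNNCell, so it needs its own initial state
initial_state = attn_cell.zero_state(batch_size, tf.float32)

# With attention_layer_size set, each step's output is a linear layer applied
# to the concatenation of the cell's hidden state and the attention context,
# which is what I mean by "linear combination" above.
attn_outputs, attn_state = tf.nn.dynamic_rnn(
    attn_cell, input_tensor,
    initial_state=initial_state,
    time_major=False)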

0 Answers:

No answers