阅读了Bahdanau paper并将其翻译成当前的tf.contrib.seq2seq API后,我对我应该提供给解码器的内容感到困惑。特别是,TrainingHelper看起来应该会收到一个时移的标签列表。
以下是我的工作示例,但我不确定它是否正确。
# Given:
# annotations: encoder outputs, reshaped to
# (batch_size, time, encoder_size)
# labels: ground truth, shaped (batch_size, FORECAST_HORIZON)
if params.get('ATTENTION') == 'Bahdanau':
bahdanau = tf.contrib.seq2seq.BahdanauAttention(
num_units=ATTENTION_SIZE,
memory=annotations,
normalize=False,
name='BahdanauAttention')
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
cell=tf.nn.rnn_cell.BasicLSTMCell(DECODER_SIZE, forget_bias=1.0),
attention_mechanism=bahdanau,
output_attention=False,
name="attention_wrapper")
helper = tf.contrib.seq2seq.TrainingHelper(
inputs=annotations, # ??????
sequence_length=[WINDOW_LENGTH]*BATCH_SIZE,
name="TrainingDecoderHelper")
请注意倒数第三行。
TrainingHelper是否应该将编码器注释输入注意力集中的解码器系统?
inputs
形状不像annotations
,则AttentionWrapper最终抱怨形状 - 系统中唯一出现这种形状的地方是编码器。attn_cell
)已经知道在哪里获得注释(这不是注意机制的重点吗?)无论如何,实际上,我得到了一个可训练的系统,但对我来说似乎有点可疑(包括它相对于简单的LSTM表现不佳的事实,但这绝对是目前切向)。