tf.contrib.seq2seq.TraininHelper在Bahdanau seq2seq实现中期望得到什么输入?

时间:2017-07-06 21:30:41

标签: tensorflow lstm rnn

阅读了Bahdanau paper并将其翻译成当前的tf.contrib.seq2seq API后,我对我应该提供给解码器的内容感到困惑。特别是,TrainingHelper看起来应该会收到一个时移的标签列表。

以下是我的工作示例,但我不确定它是否正确。

# Given: 
#    annotations: encoder outputs, reshaped to 
#    (batch_size, time, encoder_size)
#    labels: ground truth, shaped (batch_size, FORECAST_HORIZON)
if params.get('ATTENTION') == 'Bahdanau':
    bahdanau = tf.contrib.seq2seq.BahdanauAttention(
        num_units=ATTENTION_SIZE,
        memory=annotations,
        normalize=False,
        name='BahdanauAttention')
    attn_cell = tf.contrib.seq2seq.AttentionWrapper(
        cell=tf.nn.rnn_cell.BasicLSTMCell(DECODER_SIZE, forget_bias=1.0),
        attention_mechanism=bahdanau,
        output_attention=False,
        name="attention_wrapper")
    helper = tf.contrib.seq2seq.TrainingHelper(
        inputs=annotations,  # ??????
        sequence_length=[WINDOW_LENGTH]*BATCH_SIZE,
        name="TrainingDecoderHelper")

请注意倒数第三行。

TrainingHelper是否应该将编码器注释输入注意力集中的解码器系统?

  • pro:如果inputs形状不像annotations,则AttentionWrapper最终抱怨形状 - 系统中唯一出现这种形状的地方是编码器。
  • con:如果这是正确的话,解码器在哪里得到基本事实?
  • con:注意力缠绕的解码器(attn_cell)已经知道在哪里获得注释(这不是注意机制的重点吗?)

无论如何,实际上,我得到了一个可训练的系统,但对我来说似乎有点可疑(包括它相对于简单的LSTM表现不佳的事实,但这绝对是目前切向)。

0 个答案:

没有答案