I am misusing TensorFlow's tf.contrib.seq2seq module in some way, but no error is raised, so I cannot track down the bug. My problem is that my decoder outputs the same value for every position in the output sequence (in my case, a classification label between 0 and 3, inclusive). In the example below, my output sequence has 8 labels.
My code:
import tensorflow as tf
from tensorflow.contrib.rnn import LSTMStateTuple
from tensorflow.contrib.seq2seq import (AttentionWrapper, BahdanauAttention,
                                        BasicDecoder, TrainingHelper,
                                        dynamic_decode)

# Bahdanau (additive) attention over the encoder outputs
attention_mechanism = BahdanauAttention(num_units=ATTENTION_UNITS,
                                        memory=encoder_outputs,
                                        normalize=True)
attention_wrapper = AttentionWrapper(cell=self._create_lstm_cell(DECODER_SIZE),
                                     attention_mechanism=attention_mechanism,
                                     attention_layer_size=None)
attention_zero = attention_wrapper.zero_state(batch_size=self.x.shape[0],
                                              dtype=tf.float32)
# concatenate c1 and c2 from the encoder's final states
new_c = tf.concat([encoder_final_states[0].c, encoder_final_states[1].c], axis=1)
# concatenate h1 and h2 from the encoder's final states
new_h = tf.concat([encoder_final_states[0].h, encoder_final_states[1].h], axis=1)
# initial decoder state: the attention wrapper's zero state with the
# concatenated encoder state substituted in
init_state = attention_zero.clone(cell_state=LSTMStateTuple(c=new_c, h=new_h))
training_helper = TrainingHelper(inputs=self.y_actual,           # feed in ground truth
                                 sequence_length=output_length)  # feed in sequence length
decoder = BasicDecoder(cell=attention_wrapper,
                       helper=training_helper,
                       initial_state=init_state)
decoder_outputs, decoder_final_state, decoder_final_sequence_lengths = \
    dynamic_decode(decoder=decoder, impute_finished=True)
I needed to create the LSTMStateTuple because my encoder uses a bidirectional RNN.
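For context, a minimal sketch of the kind of bidirectional encoder that would produce encoder_outputs and encoder_final_states above; the question does not show this code, so ENCODER_SIZE, self.x, and the use of tf.nn.bidirectional_dynamic_rnn are assumptions:

# hypothetical encoder; ENCODER_SIZE and self.x are assumed names
(out_fw, out_bw), encoder_final_states = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=tf.nn.rnn_cell.LSTMCell(ENCODER_SIZE),
    cell_bw=tf.nn.rnn_cell.LSTMCell(ENCODER_SIZE),
    inputs=self.x,
    dtype=tf.float32)
# concatenate forward and backward outputs to serve as the attention memory
encoder_outputs = tf.concat([out_fw, out_bw], axis=2)
# encoder_final_states is a (forward, backward) pair of LSTMStateTuples,
# which is why c and h are concatenated above; the decoder cell must then
# have size 2 * ENCODER_SIZE for the cloned state to fit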
I suspect the error is in the decoder, because my encoder's outputs do not show any such uniformity. However, I may be wrong.
Answer 0 (score: 0):
The problem was indeed that I needed to set output_attention=False, because I am using Bahdanau attention.
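For reference, a minimal sketch of the corrected wrapper construction; only the output_attention flag changes relative to the question's code:

attention_wrapper = AttentionWrapper(cell=self._create_lstm_cell(DECODER_SIZE),
                                     attention_mechanism=attention_mechanism,
                                     attention_layer_size=None,
                                     # emit the cell output at each step instead
                                     # of the attention context
                                     output_attention=False)

Per the AttentionWrapper documentation, output_attention=True emits the attention at each time step (the behavior of Luong-style attention), while False emits the cell's output (the behavior of Bahdanau-style attention).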