How to correctly use TensorFlow's tensorflow.contrib.seq2seq

Date: 2017-06-22 15:54:40

Tags: tensorflow

I am somehow misusing TensorFlow's tf.contrib.seq2seq module, but it produces no error, so I cannot locate the bug. My problem is that my decoder emits the same value at every position of the output sequence (in my case, a categorical label between 0 and 3, inclusive). In the example below, my output sequence has 8 labels.

[image: example output]

My code:

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMStateTuple
from tensorflow.contrib.seq2seq import (AttentionWrapper, BahdanauAttention,
                                        BasicDecoder, TrainingHelper,
                                        dynamic_decode)

# normalized Bahdanau attention over the encoder outputs
attention_mechanism = BahdanauAttention(num_units=ATTENTION_UNITS,
                                        memory=encoder_outputs,
                                        normalize=True)

# wrap the decoder LSTM cell with the attention mechanism
attention_wrapper = AttentionWrapper(cell=self._create_lstm_cell(DECODER_SIZE),
                                     attention_mechanism=attention_mechanism,
                                     attention_layer_size=None)

# zero attention state; its cell state is replaced with the encoder's final state below
attention_zero = attention_wrapper.zero_state(batch_size=self.x.shape[0], dtype=tf.float32)

# concatenate c1 and c2 from encoder final states
new_c = tf.concat([encoder_final_states[0].c, encoder_final_states[1].c], axis=1)

# concatenate h1 and h2 from encoder final states
new_h = tf.concat([encoder_final_states[0].h, encoder_final_states[1].h], axis=1)

# define initial state using concatenated h states and c states
init_state = attention_zero.clone(cell_state=LSTMStateTuple(c=new_c, h=new_h))

training_helper = TrainingHelper(inputs=self.y_actual,  # feed in ground truth
                                 sequence_length=output_length)  # feed in sequence length

decoder = BasicDecoder(cell=attention_wrapper,
                       helper=training_helper,
                       initial_state=init_state
                       )

# run the decoder over the whole output sequence
decoder_outputs, decoder_final_state, decoder_final_sequence_lengths = dynamic_decode(decoder=decoder,
                                                                                      impute_finished=True)
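
For reference, dynamic_decode returns a BasicDecoderOutput here, whose rnn_output field holds the per-step outputs and whose sample_id field holds per-step argmax ids. A hypothetical way to turn those outputs into the four class labels (the dense projection and NUM_CLASSES are assumptions, not part of the code above):

# hypothetical projection from raw decoder outputs to class logits;
# NUM_CLASSES is assumed, since no output_layer is passed to BasicDecoder above
logits = tf.layers.dense(decoder_outputs.rnn_output, NUM_CLASSES)
predictions = tf.argmax(logits, axis=-1)  # [batch, time] predicted labels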

I need to create the LSTMStateTuple because my encoder uses a bidirectional RNN.
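
For context, a minimal sketch of how such a bidirectional encoder could produce those final states (ENCODER_SIZE, encoder_inputs, and the cell type here are assumptions, not from the code above):

# hypothetical bidirectional encoder; ENCODER_SIZE and encoder_inputs are assumed
fw_cell = tf.contrib.rnn.LSTMCell(ENCODER_SIZE)
bw_cell = tf.contrib.rnn.LSTMCell(ENCODER_SIZE)

# encoder_final_states is a (forward, backward) pair of LSTMStateTuples,
# matching the encoder_final_states[0] / encoder_final_states[1] indexing above
(fw_out, bw_out), encoder_final_states = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=fw_cell,
    cell_bw=bw_cell,
    inputs=encoder_inputs,  # [batch, time, features]
    dtype=tf.float32)

# the attention memory is both directions' outputs concatenated on the feature axis
encoder_outputs = tf.concat([fw_out, bw_out], axis=2)

Note that because c and h from both directions are concatenated, DECODER_SIZE must be twice the encoder cell size for the cloned state to fit the decoder cell.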

I suspect the bug is in the decoder, since my encoder's outputs show no such uniformity. However, I may be wrong.

1 Answer:

Answer 0 (score: 0)

The problem was indeed that I needed to set output_attention=False, because I am using Bahdanau attention. With the default output_attention=True, the AttentionWrapper emits the attention value at each step (Luong-style behavior), and since attention_layer_size=None that value is the raw context vector, which plausibly explains why every step produced the same label. With output_attention=False, the wrapper emits the cell output instead (Bahdanau-style).
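
A minimal sketch of the fix, changing only the AttentionWrapper construction from the code above:

# emit the decoder cell's output at each step (Bahdanau-style) instead of the
# attention value (Luong-style) that output_attention=True, the default, would produce
attention_wrapper = AttentionWrapper(cell=self._create_lstm_cell(DECODER_SIZE),
                                     attention_mechanism=attention_mechanism,
                                     attention_layer_size=None,
                                     output_attention=False)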