How to replace OutputProjectionWrapper in TensorFlow 2.0

Date: 2019-08-13 16:57:53

Tags: tensorflow keras tensorflow2.0

I have the following seq2seq decoder snippet with an attention mechanism. It works in TensorFlow 1.13. Now I need to upgrade to TensorFlow 2.0 with Keras, but tf.contrib.rnn.OutputProjectionWrapper has been removed in TensorFlow 2.0. How can I do this?

attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
    num_units, memory=memory,
    memory_sequence_length=self.encoder_inputs_actual_length)
cell = tf.contrib.rnn.LSTMCell(num_units)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell, attention_mechanism, attention_layer_size)
out_cell = tf.contrib.rnn.OutputProjectionWrapper(
    attn_cell, self.output_size, reuse=reuse)
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=out_cell, helper=helper,
    initial_state=out_cell.zero_state(
        dtype=tf.float32, batch_size=self.batch_size))
final_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
    decoder=decoder, output_time_major=True,
    impute_finished=True,
    maximum_iterations=self.input_steps)

I read https://www.oreilly.com/library/view/neural-networks-and/9781492037354/ch04.html, but could not figure out how to add the fully connected projection in my case.
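For context, `OutputProjectionWrapper` did only one thing: it wrapped an RNN cell and applied a learned linear projection to the cell's output at every step, which is why a `Dense` layer passed as `output_layer` to the decoder can replace it. Below is a framework-free toy sketch of that behavior; `ToyCell`, the fixed weights, and the list arithmetic are made up purely for illustration:

```python
# Toy sketch of what tf.contrib.rnn.OutputProjectionWrapper did:
# wrap a cell and apply a linear projection to its output each step.

class ToyCell:
    """A fake 'RNN cell': output and new state are input + state."""
    def __call__(self, x, state):
        out = [xi + si for xi, si in zip(x, state)]
        return out, out  # (output, new_state)

class OutputProjectionWrapper:
    """Project each cell output through a (here, fixed) weight matrix."""
    def __init__(self, cell, weights):
        self.cell = cell
        self.weights = weights  # rows: output units, cols: cell units

    def __call__(self, x, state):
        out, new_state = self.cell(x, state)
        projected = [sum(w * o for w, o in zip(row, out))
                     for row in self.weights]
        return projected, new_state  # state passes through unchanged

cell = OutputProjectionWrapper(ToyCell(), weights=[[1.0, 0.0], [1.0, 1.0]])
output, state = cell([1.0, 2.0], [0.5, 0.5])
print(output)  # [1.5, 4.0] -- projected output
print(state)   # [1.5, 2.5] -- inner cell state, untouched
```

Only the per-step output is projected; the recurrent state flows through unchanged, which is exactly what attaching `output_layer=Dense(...)` to the decoder reproduces.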

I tried using the latest seq2seq addons as follows. There are no syntax errors, but I am not sure whether it is correct: the previous TF 1.13 version quickly reaches 90% prediction accuracy, while the new TF 2.0 version stays at around 60%.

attention_mechanism = tfa.seq2seq.BahdanauAttention(
    num_units, memory, memory_sequence_length)
lstm_cell = layers.LSTMCell(num_units)
attn_cell = tfa.seq2seq.AttentionWrapper(
    lstm_cell, attention_mechanism, attention_layer_size=num_units)
output_layer = layers.Dense(self.output_size)
basic_decoder = tfa.seq2seq.BasicDecoder(
    cell=attn_cell, sampler=sampler, output_layer=output_layer,
    output_time_major=True, impute_finished=True,
    maximum_iterations=self.input_steps)
initial_state = attn_cell.get_initial_state(
    batch_size=self.batch_size, dtype=tf.float32
).clone(cell_state=encoder_final_state)
final_outputs, _, _ = basic_decoder(
    encoder_outputs_sequence, initial_state=initial_state)

Thanks.

1 Answer:

Answer 0 (score: 0)

I finally figured out why the accuracy stayed around 60%: by default, AttentionWrapper outputs the attention vector, but in my case I needed the actual cell output to compute the next step. The fix is to set output_attention=False on the AttentionWrapper:

attn_cell = tfa.seq2seq.AttentionWrapper(
    lstm_cell, attention_mechanism,
    attention_layer_size=num_units, output_attention=False)
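To see why the flag matters, here is a hypothetical, framework-free sketch of the per-step choice the flag controls (the class and values below are made up; the real `tfa.seq2seq.AttentionWrapper` computes attention over the encoder memory): with `output_attention=True` each step emits the attention vector, with `False` it emits the raw cell output, while attention is fed back into the next step either way.

```python
# Toy sketch of AttentionWrapper's output_attention flag.

class ToyAttentionWrapper:
    def __init__(self, output_attention=True):
        self.output_attention = output_attention

    def step(self, cell_output, attention_vector):
        # In both cases the attention vector still feeds the next step
        # internally; the flag only selects what is emitted downstream.
        if self.output_attention:
            return attention_vector  # default: downstream sees attention
        return cell_output           # False: downstream sees cell output

default_wrapper = ToyAttentionWrapper(output_attention=True)
fixed_wrapper = ToyAttentionWrapper(output_attention=False)

cell_out, attn = "cell_output", "attention_vector"
print(default_wrapper.step(cell_out, attn))  # attention_vector
print(fixed_wrapper.step(cell_out, attn))    # cell_output
```

So with the default, the Dense output layer was being trained on attention scores rather than on the decoder cell's actual output, which explains the accuracy gap.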

Updating it here in case anyone runs into the same problem.