Results get worse after switching from bidirectional_rnn to bidirectional_dynamic_rnn

Posted: 2017-02-04 11:03:17

Tags: python tensorflow recurrent-neural-network

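Basically, I tried to replace bidirectional_rnn with bidirectional_dynamic_rnn (feeding it reshaped input), and I get worse results on a classification task. Am I doing something wrong? Is it the reshape?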

bidirectional_rnn version (code excerpt):

encoder_embedded_inputs = [embedding_ops.embedding_lookup(
                    W, encoder_input) for encoder_input in encoder_inputs]

encoder_outputs, encoder_state_fw, encoder_state_bw  = rnn.bidirectional_rnn(
                    encoder_cell_fw,
                    encoder_cell_bw,
                    encoder_embedded_inputs,
                    sequence_length=sequence_length,
                    dtype=dtype)

encoder_state = array_ops.concat(1, [array_ops.concat(
                1, encoder_state_fw), array_ops.concat(1, encoder_state_bw)])
top_states = [array_ops.reshape(e, [-1, 1, cell.output_size * 2])
                              for e in encoder_outputs]
attention_states = array_ops.concat(1, top_states)

Classification accuracy: 95%

bidirectional_dynamic_rnn version (code excerpt):

encoder_embedded_inputs = [embedding_ops.embedding_lookup(
                    W, encoder_input) for encoder_input in encoder_inputs]
emb_size = int(encoder_embedded_inputs[0].get_shape()[1])
enc_size = len(encoder_embedded_inputs)
birnn_inputs = tf.reshape(encoder_embedded_inputs, [-1,enc_size,emb_size])

encoder_outputs, encoder_states  = rnn.bidirectional_dynamic_rnn(
                    encoder_cell_fw,
                    encoder_cell_bw,
                    birnn_inputs,
                    sequence_length=sequence_length,
                    dtype=dtype)
encoder_state_fw, encoder_state_bw = encoder_states
encoder_state = array_ops.concat(1, [array_ops.concat(
                1, encoder_state_fw), array_ops.concat(1, encoder_state_bw)])

attention_states = tf.concat(2, encoder_outputs)

Classification accuracy: 70%

1 answer:

Answer 0 (score: 0)

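OK, so I found out that tf.reshape is not the right tool for this job; I should use tf.stack and tf.transpose instead. encoder_embedded_inputs is a Python list with one [batch, emb_size] tensor per time step, so tf.reshape first packs it into a time-major [enc_size, batch, emb_size] tensor and then only reinterprets the order of the elements rather than swapping the time and batch axes. The network was therefore being fed scrambled sequences and could no longer learn.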

Wrong:

birnn_inputs = tf.reshape(encoder_embedded_inputs, [-1,enc_size,emb_size])

Correct:

birnn_inputs = tf.stack(encoder_embedded_inputs)
birnn_inputs = tf.transpose(birnn_inputs, [1,0,2])

Now it works well.
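The difference between the two approaches can be illustrated with NumPy, whose np.stack, np.reshape and np.transpose follow the same row-major semantics as tf.stack, tf.reshape and tf.transpose. This is a minimal sketch with hypothetical toy sizes (batch_size, enc_size and emb_size are chosen only for illustration), not the poster's actual model. Each element stores 10 * time_step + example_index so you can see where every value ends up.

```python
import numpy as np

# Hypothetical toy sizes: 2 examples per batch, 3 time steps, embedding size 4.
batch_size, enc_size, emb_size = 2, 3, 4

# Mimic encoder_embedded_inputs: a list with one [batch, emb] array per time step.
# Every value is 10 * time_step + example_index, so its origin stays visible.
encoder_embedded_inputs = [
    np.array([[10 * t + b] * emb_size for b in range(batch_size)])
    for t in range(enc_size)
]

# Packing the list along a new leading axis gives a time-major array:
# shape [enc_size, batch_size, emb_size].
time_major = np.stack(encoder_embedded_inputs)

# Wrong: reshape keeps the row-major element order and just regroups it,
# so consecutive "time steps" of one row come from different examples.
wrong = np.reshape(time_major, [-1, enc_size, emb_size])

# Correct: swap the time and batch axes -> [batch_size, enc_size, emb_size].
right = np.transpose(time_major, [1, 0, 2])

print(wrong[:, :, 0])  # [[ 0  1 10]
                       #  [11 20 21]]
print(right[:, :, 0])  # [[ 0 10 20]
                       #  [ 1 11 21]]
```

Both results have shape (2, 3, 4), which is why the reshape version raises no error. However, the first row of `wrong` interleaves example 0 and example 1 (values 0, 1, 10), while each row of `right` holds the time steps of a single example in order.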