Basically, I tried replacing bidirectional_rnn with bidirectional_dynamic_rnn (feeding it reshaped inputs) and got much worse results on a classification task. Am I doing something wrong with the reshape?
bidirectional_rnn version (code excerpt):
encoder_embedded_inputs = [embedding_ops.embedding_lookup(W, encoder_input)
                           for encoder_input in encoder_inputs]
encoder_outputs, encoder_state_fw, encoder_state_bw = rnn.bidirectional_rnn(
    encoder_cell_fw,
    encoder_cell_bw,
    encoder_embedded_inputs,
    sequence_length=sequence_length,
    dtype=dtype)
encoder_state = array_ops.concat(1, [array_ops.concat(1, encoder_state_fw),
                                     array_ops.concat(1, encoder_state_bw)])
top_states = [array_ops.reshape(e, [-1, 1, cell.output_size * 2])
              for e in encoder_outputs]
attention_states = array_ops.concat(1, top_states)
Classification accuracy: 95%
bidirectional_dynamic_rnn version (code excerpt):
encoder_embedded_inputs = [embedding_ops.embedding_lookup(W, encoder_input)
                           for encoder_input in encoder_inputs]
emb_size = int(encoder_embedded_inputs[0].get_shape()[1])
enc_size = len(encoder_embedded_inputs)
birnn_inputs = tf.reshape(encoder_embedded_inputs, [-1, enc_size, emb_size])
encoder_outputs, encoder_states = rnn.bidirectional_dynamic_rnn(
    encoder_cell_fw,
    encoder_cell_bw,
    birnn_inputs,
    sequence_length=sequence_length,
    dtype=dtype)
encoder_state_fw, encoder_state_bw = encoder_states
encoder_state = array_ops.concat(1, [array_ops.concat(1, encoder_state_fw),
                                     array_ops.concat(1, encoder_state_bw)])
attention_states = tf.concat(2, encoder_outputs)
Classification accuracy: 70%
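For context, the two APIs expect differently shaped inputs: the static rnn.bidirectional_rnn consumes a Python list of per-timestep tensors of shape [batch, emb], while rnn.bidirectional_dynamic_rnn (with the default time_major=False) takes a single tensor of shape [batch, time, emb]. Below is a minimal sketch of those two layouts; the sizes and names (batch_size, num_steps, emb_size) are illustrative assumptions, not part of the original code.

# Illustrative only; shapes and names are assumptions, not the poster's code.
import tensorflow as tf

batch_size, num_steps, emb_size = 32, 10, 128

# Static rnn.bidirectional_rnn: a Python list of num_steps tensors,
# each of shape [batch_size, emb_size].
static_inputs = [tf.placeholder(tf.float32, [batch_size, emb_size])
                 for _ in range(num_steps)]

# Dynamic rnn.bidirectional_dynamic_rnn (time_major=False, the default):
# one tensor of shape [batch_size, num_steps, emb_size].
dynamic_inputs = tf.placeholder(tf.float32, [batch_size, num_steps, emb_size])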
Answer (score: 0):
OK, I found out that tf.reshape is the wrong operation for this; tf.stack plus tf.transpose should be used instead. With the reshape, the network was essentially being fed scrambled inputs and could no longer learn.
Wrong:
birnn_inputs = tf.reshape(encoder_embedded_inputs, [-1,enc_size,emb_size])
Correct:
birnn_inputs = tf.stack(encoder_embedded_inputs)
birnn_inputs = tf.transpose(birnn_inputs, [1,0,2])
So now it works well.
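To see why the reshape hurt: the list of per-timestep [batch, emb] tensors is first packed into a time-major [time, batch, emb] tensor, and tf.reshape only reinterprets that buffer as [batch, time, emb], so each "sequence" ends up mixing timesteps from different examples, whereas tf.stack followed by tf.transpose actually swaps the batch and time axes. A tiny NumPy sketch (my own illustration, not part of the original code):

import numpy as np

# Two timesteps, a batch of two examples (A and B), embedding size 2.
t0 = np.array([[1, 1], [2, 2]])   # timestep 0: row 0 = A, row 1 = B
t1 = np.array([[3, 3], [4, 4]])   # timestep 1: row 0 = A, row 1 = B
stacked = np.stack([t0, t1])      # shape [time, batch, emb]

wrong = stacked.reshape(-1, 2, 2)   # "sequence 0" = [[1, 1], [2, 2]]: A@t0 mixed with B@t0
right = stacked.transpose(1, 0, 2)  # "sequence 0" = [[1, 1], [3, 3]]: A@t0 followed by A@t1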