I'm working on an RNN for a Seq2Seq model in TensorFlow. I've reached the final step of the dynamic RNN, the dynamic_decode step, and I get this error:
"ValueError: Input 0 of layer gru_cell_3 is incompatible with the layer: expected ndim=2, found ndim=1. Full shape received: [None]"
import tensorflow as tf

data_inputs = tf.placeholder(tf.float32, [None, 102, 300])
batch_lengths = tf.cast(
    tf.reduce_sum(tf.reduce_max(tf.sign(data_inputs), 2), 1), tf.int32)

encoder_cell_forward = tf.nn.rnn_cell.GRUCell(num_units=150)
encoder_cell_backward = tf.nn.rnn_cell.GRUCell(num_units=150)
_, state = tf.nn.bidirectional_dynamic_rnn(
    encoder_cell_forward, encoder_cell_backward,
    data_inputs, sequence_length=batch_lengths,
    dtype=tf.float32)
state = tf.concat(state, 1)

decoder_cell = tf.nn.rnn_cell.GRUCell(num_units=300)
helper = tf.contrib.seq2seq.TrainingHelper(state, batch_lengths)
projection_layer = tf.layers.Dense(units=300, activation=None, trainable=True)
decoder = tf.contrib.seq2seq.BasicDecoder(
    decoder_cell, helper, state,
    output_layer=projection_layer)
final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
    decoder, maximum_iterations=102, impute_finished=False)
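To see where the ndim=1 comes from, here is a small numpy sketch (the batch size 4 is illustrative): TrainingHelper treats its first argument as a [batch, time, embed] tensor and feeds one time slice to the cell per decoding step, so slicing a 3-D inputs tensor gives the 2-D step the GRUCell expects, while slicing a 2-D tensor like the concatenated encoder state gives a 1-D step of shape [None].

```python
import numpy as np

batch, time_steps, embed = 4, 102, 300

# What TrainingHelper expects: [batch, time, embed] decoder inputs.
decoder_inputs = np.zeros((batch, time_steps, embed), dtype=np.float32)
step = decoder_inputs[:, 0, :]   # one time slice: shape (4, 300), ndim 2
assert step.ndim == 2            # matches what GRUCell expects

# What the code actually passes: the 2-D encoder state [batch, 600].
state = np.zeros((batch, 600), dtype=np.float32)
bad_step = state[:, 0]           # "time slice" is shape (4,), ndim 1
assert bad_step.ndim == 1        # -> "expected ndim=2, found ndim=1"
```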
What am I doing wrong here?
Answer 0 (score: 1)
I think the problem is with the training helper:
tf.contrib.seq2seq.TrainingHelper(state, batch_lengths)
The first argument needs to be the batch of sequences you want to decode, with shape [batch, time, embed]; passing the encoder state there throws this error. The encoder state belongs in BasicDecoder as the decoder's initial state instead.