Tensorflow中的Seq2Seq,但是我得到了ValueError:图层gru_cell_3的输入0与图层不兼容:

时间:2018-04-02 22:25:37

标签: python tensorflow

我正在研究如何在Tensorflow中为Seq2Seq模型使用RNN,我进入动态RNN的最后一步,我得到一个动态的dynamic_decode步骤,我得到了错误:

" ValueError:图层gru_cell_3的输入0与图层不兼容:expected ndim = 2,found ndim = 1。收到的完整形状:[无]"

import tensorflow as tf

data_inputs = tf.placeholder(tf.float32,[None,102,300]) 

batch_lengths = tf.cast(tf.reduce_sum(tf.reduce_max(tf.sign(data_inputs),2),1),tf.int32)

encoder_cell_forward = tf.nn.rnn_cell.GRUCell(num_units = 150)

encoder_cell_backward = tf.nn.rnn_cell.GRUCell(num_units = 150)

_ , state = tf.nn.bidirectional_dynamic_rnn(
    encoder_cell_forward,encoder_cell_backward,
    data_inputs,sequence_length = batch_lengths,
    dtype = tf.float32 )

state = tf.concat(state,1)

decoder_cell = tf.nn.rnn_cell.GRUCell(num_units = 300)

helper = tf.contrib.seq2seq.TrainingHelper(state,batch_lengths)

projection_layer = tf.layers.Dense(
    units = 300,activation= None,trainable =True )

decoder = tf.contrib.seq2seq.BasicDecoder(
    decoder_cell, helper, state,
    output_layer=projection_layer)

final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
    decoder,maximum_iterations= 102,impute_finished=False)

我在这里做错了什么?

1 个答案:

答案 0 :(得分:1)

我认为这与培训助手一致:

tf.contrib.seq2seq.TrainingHelper(state,batch_lengths)

状态需要是您想要解码的序列批处理,如果使用编码状态则会抛出错误。