我正在使用Tensorflow创建Seq2Seq模型。我尝试使用迷你批处理来处理数据集。当我在Tensorflow中使用batch()
方法构建数据集时,数据集形状变为(None,10)
。但是,将数据馈送到SimpleRNNCell
时会引发错误:
ValueError: Shape must be rank 2 but is rank 1 for 'simple_rnn_cell/MatMul_1' (op: 'MatMul') with input shapes: [10], [10,10].
代码如下:
def decoder(self, input_x, real_y, encoder_outputs, training=False):
decoder_state, cell_states = encoder_outputs, []
predict_shape = (5, 1)
output = tf.convert_to_tensor(np.zeros(predict_shape), dtype=tf.float32)
for x in range(self.max_output):
# below code raises error, here output and decoder_state shape is (5, 1) (?, 10)
output, decoder_state = self.decoder_rnn(output, decoder_state)