如何使用tensorflow SimpleRNNCell过程批处理数据集?

时间:2019-03-28 12:43:04

标签: tensorflow keras deep-learning

我正在使用Tensorflow创建Seq2Seq模型。我尝试使用迷你批处理来处理数据集。当我在Tensorflow中使用batch()方法构建数据集时,数据集形状变为(None,10)。但是,将数据馈送到SimpleRNNCell时会引发错误:

ValueError: Shape must be rank 2 but is rank 1 for 'simple_rnn_cell/MatMul_1' (op: 'MatMul') with input shapes: [10], [10,10].

代码如下:

    def decoder(self, input_x, real_y, encoder_outputs, training=False):

      decoder_state, cell_states = encoder_outputs, []

      predict_shape = (5, 1)
      output = tf.convert_to_tensor(np.zeros(predict_shape), dtype=tf.float32)

      for x in range(self.max_output):
        # below code raises error, here output and decoder_state shape is  (5, 1) (?, 10)
        output, decoder_state = self.decoder_rnn(output, decoder_state)

0 个答案:

没有答案