How to fix "ValueError: could not broadcast input array from shape (2,...) into shape (1,...)" in model.predict?

Posted: 2019-06-21 04:28:48

Tags: python numpy tensorflow machine-learning keras

After training an encoder-decoder network, I use it to make predictions. However, feeding it the same data I used for training raises a ValueError.

I tried setting 'batch_size=1'. I also tried running it on just a single input.
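A likely reason 'batch_size=1' does not help: the `build` function below joins the encoder states with `tf.concat([...], 0)`, and axis 0 is the batch axis. The sketch uses NumPy, whose `np.concatenate` interprets `axis` the same way `tf.concat` does, with hypothetical sizes. It shows that joining two `(1, hidden)` state tensors along axis 0 yields a batch of 2. Joining them along the last (feature) axis keeps the batch at 1.

```python
import numpy as np

# Hypothetical sizes: a batch of 1 (as with batch_size=1) and 4 hidden units.
batch, hidden = 1, 4
forward_h = np.zeros((batch, hidden))
backward_h = np.ones((batch, hidden))

# Axis 0 is the batch axis: the two states become two separate "samples".
wrong = np.concatenate([forward_h, backward_h], axis=0)
# The last axis is the feature axis: each sample gets one 2*hidden vector.
right = np.concatenate([forward_h, backward_h], axis=-1)

print(wrong.shape)  # (2, 4)
print(right.shape)  # (1, 8)
```

When the doubled state is used as the decoder's initial state, the decoder output has a leading dimension of 2 for every input batch of 1. That matches the `(2,43,241)` vs `(1,43,241)` shapes in the traceback.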


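import tensorflow as tf
from tensorflow.keras.layers import Input, LSTM, Bidirectional, Dense
from tensorflow.keras.models import Model


def build(embedding_size, hidden_size):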
    encoder_inputs = Input(shape=(None, embedding_size))
    encoder = Bidirectional(LSTM(hidden_size, return_state=True))
    encoder_outputs, forward_h, forward_c, backward_h, backward_c = encoder(encoder_inputs)
    encoder_states = [tf.concat([forward_h, backward_h], 0), tf.concat([forward_c, backward_h], 0)]

    decoder_inputs = Input(shape=(None, embedding_size))
    decoder_lstm = LSTM(hidden_size, return_sequences=True, return_state=True)
    decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                         initial_state=encoder_states)
    decoder_dense = Dense(embedding_size, activation='relu')
    decoder_outputs = decoder_dense(decoder_outputs)

    model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

    return model


model = build(**params)
model.load_weights('model/model-01.hdf5')

# Shape of encoder_input_data is (2752, 43, 241)
# Shape of decoder_input_data is (2752, 43, 241)

predictions = model.predict([encoder_input_data, decoder_input_data], batch_size=1)

Traceback (most recent call last):
  File "C:/Users/mason-act5/NLP/mrs_paraphrase/predict.py", line 48, in <module>
    predictions = model.predict([encoder_input_data, decoder_input_data], batch_size=1)
  File "C:\Users\mason-act5\NLP\rnn\venv\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1078, in predict
    callbacks=callbacks)
  File "C:\Users\mason-act5\NLP\rnn\venv\lib\site-packages\tensorflow\python\keras\engine\training_arrays.py", line 370, in model_iteration
    aggregator.aggregate(batch_outs, batch_start, batch_end)
  File "C:\Users\mason-act5\NLP\rnn\venv\lib\site-packages\tensorflow\python\keras\engine\training_utils.py", line 169, in aggregate
    self.results[i][batch_start:batch_end] = batch_out
ValueError: could not broadcast input array from shape (2,43,241) into shape (1,43,241)
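Below is a sketch of a corrected `build`, under stated assumptions, not a confirmed fix. It makes three changes. First, it concatenates the states along the last axis instead of axis 0. Second, it pairs `forward_c` with `backward_c`, since the original cell-state line uses `backward_h`. Third, it gives the decoder LSTM `2 * hidden_size` units so that its state size matches the concatenated states. It uses the Keras `Concatenate` layer in place of `tf.concat`, although `tf.concat(..., axis=-1)` expresses the same operation. The sizes in the usage check are hypothetical. Because the decoder's weight shapes change, the saved weights in `model/model-01.hdf5` would not load into this model, so it would presumably need to be retrained.

```python
import numpy as np
from tensorflow.keras.layers import Input, LSTM, Bidirectional, Dense, Concatenate
from tensorflow.keras.models import Model


def build(embedding_size, hidden_size):
    encoder_inputs = Input(shape=(None, embedding_size))
    encoder = Bidirectional(LSTM(hidden_size, return_state=True))
    _, forward_h, forward_c, backward_h, backward_c = encoder(encoder_inputs)
    # Join along the feature axis (last axis), not the batch axis, and pair
    # hidden state with hidden state, cell state with cell state.
    state_h = Concatenate(axis=-1)([forward_h, backward_h])  # (batch, 2*hidden)
    state_c = Concatenate(axis=-1)([forward_c, backward_c])  # (batch, 2*hidden)

    decoder_inputs = Input(shape=(None, embedding_size))
    # The decoder's unit count must equal the size of the concatenated states.
    decoder_lstm = LSTM(2 * hidden_size, return_sequences=True, return_state=True)
    decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                         initial_state=[state_h, state_c])
    decoder_outputs = Dense(embedding_size, activation='relu')(decoder_outputs)
    return Model([encoder_inputs, decoder_inputs], decoder_outputs)


# Hypothetical small sizes, only to check output shapes.
model = build(embedding_size=8, hidden_size=4)
enc = np.random.rand(3, 5, 8).astype("float32")
dec = np.random.rand(3, 5, 8).astype("float32")
predictions = model.predict([enc, dec], batch_size=1, verbose=0)
print(predictions.shape)  # (3, 5, 8)
```

With the states concatenated on the feature axis, the output's batch dimension equals the input's for any `batch_size`, so `predict` no longer has a mismatch to broadcast.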
