Model.predict raises "Matrix size-incompatible"

Time: 2019-11-01 13:24:35

Tags: python-3.x tensorflow keras

I trained a model that was saved by a ModelCheckpoint callback. I loaded it and ran prediction with keras.Model.predict, but got a "Matrix size-incompatible" error, shown below.

I checked the shape of the data I want to predict on, and it is correct.

Any suggestions?

Code


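    import tensorflow as tf  # TF 1.x API: make_one_shot_iterator, session.run

    # Excerpt from predict() in ingerop_prediction.py: `dataset`, `model_file`
    # and `session` (a TF 1.x session) are created earlier in that function.
    print("***dataset:")
    print(dataset)

    print("***Show shape")
    # Run through the whole dataset once and print the shape of every batch.
    iterator = dataset.make_one_shot_iterator()
    next_batch = iterator.get_next()
    try:
        while True:
            data = session.run(next_batch)
            print(data.shape)
    except tf.errors.OutOfRangeError:
        pass  # dataset exhausted

    print("***Load model and predict")
    model = tf.keras.models.load_model(model_file)
    model.summary()
    predictions = model.predict(dataset)  # raises "Matrix size-incompatible"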

Output

***dataset:
<DatasetV1Adapter shapes: (?, 1, 64, ?), types: tf.float32>
***Show shape
(1, 1, 64, 169)
(1, 1, 64, 169)
(1, 1, 64, 169)
...
(1, 1, 64, 169)
(1, 1, 64, 169)
(1, 1, 64, 169)
(1, 1, 64, 169)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mels (InputLayer)            [(32, 1, 64, 169)]        0         
_________________________________________________________________
l1_conv (Conv2D)             (32, 32, 62, 167)         288       
_________________________________________________________________
l1_bn (BatchNormalization)   (32, 32, 62, 167)         96        
_________________________________________________________________
l1 (Activation)              (32, 32, 62, 167)         0         
_________________________________________________________________
l1_mp (MaxPooling2D)         (32, 32, 30, 83)          0         
_________________________________________________________________
l2_conv (Conv2D)             (32, 32, 28, 81)          9216      
_________________________________________________________________
l2_bn (BatchNormalization)   (32, 32, 28, 81)          96        
_________________________________________________________________
l2 (Activation)              (32, 32, 28, 81)          0         
_________________________________________________________________
l2_mp (MaxPooling2D)         (32, 32, 13, 40)          0         
_________________________________________________________________
l3_conv (Conv2D)             (32, 32, 11, 38)          9216      
_________________________________________________________________
l3_bn (BatchNormalization)   (32, 32, 11, 38)          96        
_________________________________________________________________
l3 (Activation)              (32, 32, 11, 38)          0         
_________________________________________________________________
l3_mp (MaxPooling2D)         (32, 32, 5, 18)           0         
_________________________________________________________________
flatten (Flatten)            (32, 2880)                0         
_________________________________________________________________
logits (Dense)               (32, 100)                 288100    
_________________________________________________________________
dense (Dense)                (32, 10)                  1010      
=================================================================
Total params: 308,118
Trainable params: 307,926
Non-trainable params: 192

***Load model and predict
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-14-edec6ee91517> in <module>
----> 1 predict('/home/jul/data/xenocanto/audio/wav_22050hz_MLR/XC164420.M.wav', '/home/jul/data/ingerop/subset_1572008350/features/actdet_config.json', '/home/jul/data/ingerop/subset_1572008350/features/featex_config.json', '/home/jul/data/ingerop/subset_1572008350/run_1572428779/models/model.05-0.92.h5')

~/dev/phaunos_ml/phaunos_ml/experiments/ingerop_prediction.py in predict(audio_filename, actdet_cfg_file, featex_cfg_file, model_file)
    106     model.summary()
    107 
--> 108     predictions = model.predict(dataset)
    109 
    110     return predictions

~/.miniconda3/envs/phaunos_ml/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
   1076           verbose=verbose,
   1077           steps=steps,
-> 1078           callbacks=callbacks)
   1079 
   1080   def reset_metrics(self):

~/.miniconda3/envs/phaunos_ml/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_freq, mode, validation_in_fit, prepared_feed_values_from_dataset, steps_name, **kwargs)
    272           # `ins` can be callable in tf.distribute.Strategy + eager case.
    273           actual_inputs = ins() if callable(ins) else ins
--> 274           batch_outs = f(actual_inputs)
    275         except errors.OutOfRangeError:
    276           if is_dataset:

~/.miniconda3/envs/phaunos_ml/lib/python3.6/site-packages/tensorflow/python/keras/backend.py in __call__(self, inputs)
   3290 
   3291     fetched = self._callable_fn(*array_vals,
-> 3292                                 run_metadata=self.run_metadata)
   3293     self._call_fetch_callbacks(fetched[-len(self._fetches):])
   3294     output_structure = nest.pack_sequence_as(

~/.miniconda3/envs/phaunos_ml/lib/python3.6/site-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
   1456         ret = tf_session.TF_SessionRunCallable(self._session._session,
   1457                                                self._handle, args,
-> 1458                                                run_metadata_ptr)
   1459         if run_metadata:
   1460           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: Matrix size-incompatible: In[0]: [32,90], In[1]: [2880,100]
     [[{{node logits_9/MatMul}}]]
  (1) Invalid argument: Matrix size-incompatible: In[0]: [32,90], In[1]: [2880,100]
     [[{{node logits_9/MatMul}}]]
     [[dense_9/Sigmoid/_2867]]
0 successful operations.
0 derived errors ignored.
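
As a cross-check on the summary, the spatial sizes it prints are consistent with 3×3 unpadded ("valid") convolutions, each followed by 3×3 max pooling with stride 2, on channels-first input of shape (1, 64, 169). The kernel, pool and stride values are inferred from the printed shapes and parameter counts (the model code is not shown), so this is only a sketch under those assumptions. It confirms that one example flattens to 32 × 5 × 18 = 2880 features, which is the input size of the 2880×100 kernel of the `logits` layer.

```python
def conv_valid(size, kernel=3):
    # Output length of an unpadded convolution with stride 1 (assumed 3x3 kernel).
    return size - kernel + 1


def max_pool(size, pool=3, stride=2):
    # Output length of unpadded max pooling (assumed pool 3, stride 2).
    return (size - pool) // stride + 1


def flatten_features(height=64, width=169, channels=32, blocks=3):
    # Propagate the spatial size through conv -> pool blocks and record
    # (layer kind, channels, height, width) after each step.
    h, w = height, width
    shapes = []
    for _ in range(blocks):
        h, w = conv_valid(h), conv_valid(w)
        shapes.append(("conv", channels, h, w))
        h, w = max_pool(h), max_pool(w)
        shapes.append(("pool", channels, h, w))
    # Number of features per example after Flatten.
    return shapes, channels * h * w


shapes, n_features = flatten_features()
for s in shapes:
    print(s)
print(n_features)  # 2880
```

The printed shapes match the non-batch dimensions of l1_conv through l3_mp in the summary.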

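One possible reading of the numbers in the error: the summary shows the batch dimension fixed at 32 in every layer, while the dataset yields batches of 1. If the flatten step reshapes using that fixed batch size, the 32 × 5 × 18 = 2880 values of a single example are split into 32 rows of 90, which is exactly the [32,90] left operand of `logits_9/MatMul`, and it cannot be multiplied by the [2880,100] kernel. This is a hypothesis that fits the shapes, not a confirmed diagnosis; the NumPy sketch below only reproduces the arithmetic and does not use TensorFlow.

```python
import numpy as np

# Output of l3_mp for one example when the dataset batch size is 1
# (channels-first layout: batch, channels, height, width).
pooled = np.zeros((1, 32, 5, 18), dtype=np.float32)

# Flatten that keeps the batch dimension as given: 1 row of 2880 features.
dynamic = pooled.reshape(pooled.shape[0], -1)

# Flatten with the batch dimension forced to 32, as in the model summary
# (assumption: this is how the fixed batch size affects the reshape).
fixed = pooled.reshape(32, -1)

# Kernel of the "logits" Dense layer: 2880 inputs -> 100 units.
kernel = np.zeros((2880, 100), dtype=np.float32)

print(dynamic.shape)             # (1, 2880)
print(fixed.shape)               # (32, 90)
print((dynamic @ kernel).shape)  # (1, 100)
try:
    fixed @ kernel               # [32, 90] x [2880, 100]: inner dims differ
except ValueError as err:
    print("matmul fails:", err)
```

If this reading is right, feeding batches of exactly 32, or rebuilding the model for inference without a fixed batch size (for example, a `tf.keras.Input` without `batch_size`) and loading the trained weights into it, would avoid the mismatch. Neither option is verified here.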