我正在尝试将model.predict_classes(x)
用于训练有素的模型,但出现错误:
ValueError: Error when checking input: expected conv2d_1_input to have 4 dimensions, but got array with shape (8, 4)
我用相同的错误信息检查了其他问题,但没有一个解决我的问题,或者至少我无法理解特别解决我的问题意味着什么。
为了能够重现该错误,您可以在Google colab中运行以下代码段:
首次加载数据集:
from google_drive_downloader import GoogleDriveDownloader as gdd
gdd.download_file_from_google_drive(
file_id='13WSlx4cmXh3wfvzNEXbZAbc2a1RHpPPP',
dest_path='./data/klasifikacia.zip',
unzip=True)
然后进行负荷训练模型:
gdd.download_file_from_google_drive(
file_id='1k3Lz79cF40peKCf-UdobyIUdbKZ2Onfp',
dest_path='./data/model.h5',
unzip=False)
最后,运行预测:
import keras
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model
from matplotlib import pyplot as plt
import matplotlib.pyplot as plt
%matplotlib inline
test_path = './data/test'
model_name = 'keras_drone_trained_model.h5'
test_batches = ImageDataGenerator().flow_from_directory(test_path, target_size=(100, 100), classes=['biker', 'pedestrian', 'golf_cart', 'skater'], batch_size=8)
test_datagen = ImageDataGenerator(rescale=1./255)
def plots(ims, figsize=(12,6), rows=1, interp=False, titles=None):
if type(ims[0]) is np.ndarray:
ims = np.array(ims).astype(np.uint8)
if(ims.shape[-1] != 3):
ims = ims.transpose((0,2,3,1))
f = plt.figure(figsize=figsize)
cols = len(ims)//rows if len(ims)%2 == 0 else len(ims)//rows + 1
for i in range(len (ims)):
sp = f.add_subplot(rows, cols, i+1)
sp.axis('Off')
if titles is not None:
sp.set_title(titles[i], fontsize=14)
plt.imshow(ims[i], interpolation=None if interp else 'none')
test_imgs, test_labels = next(test_batches)
plots(test_imgs, titles=test_labels)
classes = test_batches.class_indices
print(classes)
model = load_model('data/model.h5')
x = model.predict_generator(test_batches, steps=1, verbose=0)
predict = model.predict_classes(x)
predict
所以输入应该有4个维,但是什么数组的形状为(8,4)? 我的代码中的问题出在哪里?如果您还需要查看我的模型,请在评论中告诉我。
编辑:
summary()
的几行:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_7 (Conv2D) (None, 100, 100, 32) 896
_________________________________________________________________
activation_7 (Activation) (None, 100, 100, 32) 0
_________________________________________________________________
batch_normalization_7 (Batch (None, 100, 100, 32) 128
_________________________________________________________________
conv2d_8 (Conv2D) (None, 100, 100, 32) 9248