使用CNN模型的图像分类器值错误

时间:2019-12-10 11:48:47

标签: python keras conv-neural-network

您好,我一直在制作具有多个类别的图像分类器,而我不断收到此错误

  

ValueError :检查目标时出错:期望density_1的形状为(1,),但数组的形状为(9,)

idk我在这里做错的是我的代码,通过将类别更改为二进制,可以在2类分类器上正常工作

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dense
from keras.layers import Activation, Dropout, Flatten, Dense
from keras.optimizers import Adam
from keras import backend as K
from PIL import ImageFile, Image
print(Image.__file__)
import numpy
import matplotlib.pyplot as plt

# dimensions of our images.
img_width, img_height = 256, 256

train_data_dir = r'C:\Users\Acer\imagerec\Mushrooms\TRAIN'
validation_data_dir = r'C:\Users\Acer\imagerec\Mushrooms\VAL'
nb_train_samples = 6714
nb_validation_samples = 6262
epochs = 20
batch_size = 30

if K.image_data_format() == 'channels_first':
    input_shape = (1, img_width, img_height)
else:
    input_shape = (img_width, img_height, 1)

from keras.applications.vgg19 import VGG19
from keras.models import Model
from keras.layers import Dense

vgg = VGG19(include_top=False, weights='imagenet', input_shape=(), pooling='avg')
x = vgg.output
x = Dense(1, activation='sigmoid')(x)
model = Model(vgg.input, x)
model.summary()

model.compile(loss='sparse_categorical_crossentropy',
              optimizer='Adam',
              metrics=['accuracy'])

# this is the augmentation configuration we will use for training
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')

validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')

model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size)

from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
import seaborn as sns

test_steps_per_epoch = numpy.math.ceil(validation_generator.samples / validation_generator.batch_size)

predictions = model.predict_generator(validation_generator, steps=test_steps_per_epoch)
# Get most likely class
predicted_classes = numpy.argmax(predictions, axis=1)
true_classes = validation_generator.classes
class_labels = list(validation_generator.class_indices.keys())
report = classification_report(true_classes, predicted_classes, target_names=class_labels)
print(report)

cm=confusion_matrix(true_classes,predicted_classes)

sns.heatmap(cm, annot=True)

print(cm)

plt.show()

这是错误

Traceback (most recent call last):
  File "C:/Users/Acer/PycharmProjects/condas/VGG19.py", line 69, in <module>
    validation_steps=nb_validation_samples // batch_size)
  File "C:\Users\Acer\Anaconda3\envs\condas\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "C:\Users\Acer\Anaconda3\envs\condas\lib\site-packages\keras\engine\training.py", line 1732, in fit_generator
    initial_epoch=initial_epoch)
  File "C:\Users\Acer\Anaconda3\envs\condas\lib\site-packages\keras\engine\training_generator.py", line 220, in fit_generator
    reset_metrics=False)
  File "C:\Users\Acer\Anaconda3\envs\condas\lib\site-packages\keras\engine\training.py", line 1508, in train_on_batch
    class_weight=class_weight)
  File "C:\Users\Acer\Anaconda3\envs\condas\lib\site-packages\keras\engine\training.py", line 621, in _standardize_user_data
    exception_prefix='target')
  File "C:\Users\Acer\Anaconda3\envs\condas\lib\site-packages\keras\engine\training_utils.py", line 145, in standardize_input_data
    str(data_shape))
ValueError: Error when checking target: expected dense_1 to have shape (1,) but got array with shape (9,)

0 个答案:

没有答案
相关问题