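This is the code I wrote to classify images into classes consisting of birds, dogs and cats. It is the same as my binary classification code, but when I added another class and changed the loss function in the compile method to categorical_crossentropy, it gave the following error (=> at the end of the code). Can anyone explain the problem here, or the mistake I made?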
# Importing Keras and Tensorflow modules
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.layers import Flatten
from keras.layers import Dense
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model
from keras.utils.np_utils import to_categorical
import os.path
# Initialize the CNN
classifier = Sequential()
# Step 1 - Convolution
classifier.add(Conv2D(32, (3, 3), input_shape = (64, 64, 3), activation = 'relu'))
# Step 2 - Pooling
classifier.add(MaxPooling2D(pool_size = (2, 2)))
# Step 2(b) - Add 2nd Convolution Layer making it Deep followed by a Pooling Layer
classifier.add(Conv2D(32, (3, 3), activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2)))
# Step 3 - Flattening
classifier.add(Flatten())
# Step 4 - Fully Connected Neural Network
# Hidden Layer - Activation Function RELU
classifier.add(Dense(units = 128, activation = 'relu'))
# Output Layer - Activation Function Softmax (to classify multiple classes)
classifier.add(Dense(units = 1, activation = 'softmax'))
# Compile the CNN
# Categorical Crossentropy - to classify between multiple classes of images
classifier.compile(optimizer = 'adam', loss = 'categorical_crossentropy',
                   metrics = ['accuracy'])
# Image Augmentation and Training Section
# Image Augmentation to prevent Overfitting (applying random transformations
# to the training images, i.e. scaling, rotating and stretching)
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
training_set = train_datagen.flow_from_directory(
    'dataset/training_set',
    target_size=(64, 64),
    batch_size=8,
    class_mode='categorical')
test_set = test_datagen.flow_from_directory(
    'dataset/test_set',
    target_size=(64, 64),
    batch_size=8,
    class_mode='categorical')
# Fit the classifier on the CNN data
if not os.path.isfile('my_model.h5'):
    classifier.fit_generator(
        training_set,
        steps_per_epoch=8000,
        epochs=2,
        validation_data=test_set,
        validation_steps=2000
    )
    # Save the generated model to my_model.h5
    classifier.save('my_model.h5')
else:
    classifier = load_model('my_model.h5')
Answer 0 (score: 2)
Your dataset appears to have 3 classes, so you need to change the last line of the model definition to:
classifier.add(Dense(units = 3, activation = 'softmax'))
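To see why a single softmax unit cannot work, here is a minimal pure-Python sketch. The softmax helper below is written for illustration only and is not the Keras implementation. Softmax normalizes across the units of the layer, so with one unit the output is always 1.0 whatever the input, while with three units it yields one probability per class.

```python
import math

def softmax(logits):
    """Illustrative softmax over a list of raw scores (not the Keras code)."""
    peak = max(logits)
    # Subtract the maximum before exponentiating for numerical stability.
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# One output unit: the result is 1.0 no matter what the network computes.
single_unit = [softmax([x])[0] for x in (-5.0, 0.0, 7.3)]

# Three output units: one probability per class, summing to 1.
three_units = softmax([2.0, 1.0, 0.1])
```

With class_mode='categorical' the generator produces one-hot labels with one entry per class, so the output layer needs the same number of units for the shapes to match.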
Answer 1 (score: 1)
You need one neuron per class in your last (Dense) layer.
classifier.add(Dense(3, activation = 'softmax'))
Right now you have only one neuron, so your network is still set up for only two classes.
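One way to avoid hard-coding the number of output neurons is to derive it from the dataset folder. The sketch below assumes the usual layout in which flow_from_directory treats each immediate subdirectory as one class; count_classes is a hypothetical helper name, not part of Keras.

```python
import os

def count_classes(directory):
    """Hypothetical helper: count immediate subdirectories, one per class."""
    return sum(
        1
        for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )

# Assumed usage with the question's layout:
#   num_classes = count_classes('dataset/training_set')
#   classifier.add(Dense(units = num_classes, activation = 'softmax'))
```

The same number should also be available from the generator once it is created, for example as len(training_set.class_indices).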