Multi-class classification: high validation accuracy, but poor predictions on the test set

Date: 2019-03-29 11:54:03

Tags: keras deep-learning classification conv-neural-network multiclass-classification

I am trying to classify images belonging to 16 classes. The images show different geometric shapes (see Fig. 2). The training set contains 16 x 320 = 5120 images, the validation set contains 16 x 160 = 2560 images, and the test set contains 16 x 2 = 32 images.

I am using the following code to build the CNN and make predictions.

import numpy as np
np.random.seed(0)

import keras
from keras.models import Sequential,Input,Model
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU 
from keras import regularizers
from keras.layers import Activation

num_classes = 16
classifier = Sequential()
classifier.add(Conv2D(32, kernel_size=(3, 3),activation='relu',input_shape=(64, 64, 3),padding='same'))
classifier.add(MaxPooling2D((2, 2),padding='same'))

classifier.add(Dropout(0.2))

classifier.add(Conv2D(64, (3, 3), activation='relu',padding='same'))
#classifier.add(LeakyReLU(alpha=0.1))
classifier.add(MaxPooling2D(pool_size=(2, 2),padding='same'))

classifier.add(Dropout(0.2))

classifier.add(Conv2D(64, (3, 3), activation='relu',padding='same'))
classifier.add(MaxPooling2D(pool_size=(2, 2),padding='same'))

classifier.add(Dropout(0.25))

classifier.add(Conv2D(128, (3, 3), activation='relu',padding='same'))                 
classifier.add(MaxPooling2D(pool_size=(2, 2),padding='same'))

classifier.add(Dropout(0.25))

classifier.add(Flatten())
classifier.add(Dense(128, activation='relu'))    

classifier.add(Dropout(0.25))

classifier.add(Dense(num_classes, activation='softmax'))


# Compiling the CNN
classifier.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])

from keras.preprocessing.image import ImageDataGenerator
from IPython.display import display
from PIL import Image

train_datagen = ImageDataGenerator(rescale = 1./255,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   width_shift_range=0.1, 
                                   height_shift_range=0.1)                              

test_datagen = ImageDataGenerator(rescale = 1./255)

training_set = train_datagen.flow_from_directory('dataset/training_set',
                                                 target_size = (64, 64),
                                                 batch_size = 32,
                                                 class_mode = 'categorical')

test_set = test_datagen.flow_from_directory('dataset/test_set',
                                            target_size = (64, 64),
                                            batch_size = 32,
                                            class_mode = 'categorical')

from keras.callbacks import ModelCheckpoint
from keras.callbacks import EarlyStopping

STEP_SIZE_TRAIN = training_set.n//training_set.batch_size
STEP_SIZE_TEST = test_set.n//test_set.batch_size

early_stopping_callback = EarlyStopping(monitor='val_loss', patience=3)
checkpoint_callback = ModelCheckpoint('model' + '.h5', monitor='val_loss', verbose=1, save_best_only=True, mode='min')

classifier.fit_generator(training_set,
                    steps_per_epoch = STEP_SIZE_TRAIN,
                    epochs = 10,
                    validation_data = test_set,
                    validation_steps = STEP_SIZE_TEST,
                    callbacks=[early_stopping_callback, checkpoint_callback],
                    workers = 32)

from keras.models import load_model
model = load_model('model.h5')


# Part 3 - making new predictions
import numpy as np
from keras.preprocessing import image
for i in range(1,33):
    test_image = image.load_img('dataset/single_prediction/Image ' + str(i) +'.bmp', target_size = (64, 64))
    test_image = image.img_to_array(test_image)
    test_image = np.expand_dims(test_image, axis = 0)
    #print(model.predict(test_image)[0])
    print(model.predict(test_image)[0].argmax()+1)

For training and validation accuracy and loss I get the following results.

Epoch 1/10
160/160 [==============================] - 29s 179ms/step - loss: 1.3693 - acc: 0.5299 - val_loss: 0.1681 - val_acc: 0.9297

Epoch 00001: val_loss improved from inf to 0.16809, saving model to model.h5
Epoch 2/10
160/160 [==============================] - 18s 112ms/step - loss: 0.2668 - acc: 0.8984 - val_loss: 0.0773 - val_acc: 0.9699

Epoch 00002: val_loss improved from 0.16809 to 0.07725, saving model to model.h5
Epoch 3/10
160/160 [==============================] - 18s 111ms/step - loss: 0.1469 - acc: 0.9482 - val_loss: 0.0133 - val_acc: 1.0000

Epoch 00003: val_loss improved from 0.07725 to 0.01327, saving model to model.h5
Epoch 4/10
160/160 [==============================] - 18s 111ms/step - loss: 0.0990 - acc: 0.9650 - val_loss: 0.0147 - val_acc: 1.0000

Epoch 00004: val_loss did not improve from 0.01327
Epoch 5/10
160/160 [==============================] - 18s 113ms/step - loss: 0.0700 - acc: 0.9740 - val_loss: 7.3014e-04 - val_acc: 1.0000

Epoch 00005: val_loss improved from 0.01327 to 0.00073, saving model to model.h5
Epoch 6/10
160/160 [==============================] - 18s 114ms/step - loss: 0.0545 - acc: 0.9809 - val_loss: 0.0012 - val_acc: 1.0000

Epoch 00006: val_loss did not improve from 0.00073
Epoch 7/10
160/160 [==============================] - 18s 111ms/step - loss: 0.0374 - acc: 0.9865 - val_loss: 0.0101 - val_acc: 1.0000

Epoch 00007: val_loss did not improve from 0.00073
Epoch 8/10
160/160 [==============================] - 18s 111ms/step - loss: 0.0489 - acc: 0.9832 - val_loss: 0.0200 - val_acc: 0.9992

When I test the model on the 32 images of the test set, I get only 3 correct predictions. So my questions are:

1) Why is my validation accuracy so high while the model fails on the test set?

2) How can I display a random sample of the validation set (e.g. 10 images) together with their predicted classes, so I can see how the CNN behaves on the validation set? (A rough sketch of what I mean follows after question 3.)

3) Any general tips on how to improve the accuracy on the test set?
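
For question 2, something along these lines is roughly what I have in mind (just a sketch, not tested code; it assumes matplotlib is available and that the class sub-folders under dataset/test_set are the class names, which is how flow_from_directory derives its class indices):

import os, random
import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing import image
from keras.models import load_model

model = load_model('model.h5')

# flow_from_directory assigns class indices in alphabetical order of the
# sub-folder names, so sorting the folder names should reproduce that mapping.
val_dir = 'dataset/test_set'
class_names = sorted(d for d in os.listdir(val_dir)
                     if os.path.isdir(os.path.join(val_dir, d)))

# collect (path, true class) pairs from the validation directory
samples = [(os.path.join(val_dir, c, f), c)
           for c in class_names
           for f in os.listdir(os.path.join(val_dir, c))]

plt.figure(figsize=(15, 6))
for i, (path, true_class) in enumerate(random.sample(samples, 10)):
    img = image.load_img(path, target_size=(64, 64))
    x = image.img_to_array(img) / 255.0   # same rescaling as the generators
    pred = model.predict(np.expand_dims(x, axis=0))[0]
    plt.subplot(2, 5, i + 1)
    plt.imshow(img)
    plt.axis('off')
    plt.title('true: %s\npred: %s' % (true_class, class_names[pred.argmax()]))
plt.show()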

Thanks for your help! Much appreciated :)

1 Answer:

Answer (score: 0)

There can be many reasons for this, but the one I will address is a difference in data distribution between your training, validation, and test sets.

Ideally, your training, validation, and test sets should all come from the same distribution. That sounds abstract, but in practice it means not only that each split should contain the same proportion of data points per class, but also that the splits should match along many other dimensions. One such dimension is image quality: you should avoid having high-quality images in the training and validation sets but low-quality images in the test set. There are many more dimensions like this along which the splits should ideally be identical.
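
As a first sanity check along this line, you could compare the per-class image counts of your splits. A minimal sketch, assuming the directory layout from your question (one sub-folder per class under each split directory):

import os

def class_counts(root):
    """Return {class_name: number_of_images} for one split directory."""
    return {c: len(os.listdir(os.path.join(root, c)))
            for c in sorted(os.listdir(root))
            if os.path.isdir(os.path.join(root, c))}

for split in ('dataset/training_set', 'dataset/test_set'):
    print(split, class_counts(split))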

Splitting is therefore partly a game of chance: however unlikely it may be, bad luck can leave you with a test split that differs from the other splits, and that is exactly the situation you want to avoid.

To overcome this:

1) Choose a different seed when shuffling the data before splitting it

2) Choose equal split sizes for the test and validation sets

3) Use k-fold cross-validation

You can apply any of the steps above individually or in combination (a rough sketch of them is given after the link below). You can read more here:

https://cs230-stanford.github.io/train-dev-test-split.html
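
As a rough sketch of the three points, not your exact pipeline: it assumes scikit-learn is available and that all images are pooled in a hypothetical dataset/all directory (one sub-folder per class) before splitting:

import os
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedKFold

# Collect every image path with its class label from the pooled directory.
# 'dataset/all' is hypothetical: all 16 class sub-folders before any splitting.
root = 'dataset/all'
paths, labels = [], []
for cls in sorted(os.listdir(root)):
    for fname in os.listdir(os.path.join(root, cls)):
        paths.append(os.path.join(root, cls, fname))
        labels.append(cls)
paths, labels = np.array(paths), np.array(labels)

# 1) + 2): re-shuffle with a new seed and give validation and test sets the same
# size; stratify= keeps the class proportions identical in every split.
train_p, rest_p, train_y, rest_y = train_test_split(
    paths, labels, test_size=0.3, random_state=7, stratify=labels)
val_p, test_p, val_y, test_y = train_test_split(
    rest_p, rest_y, test_size=0.5, random_state=7, stratify=rest_y)

# 3) k-fold cross-validation: every image is used for validation exactly once,
# so a single unlucky split can no longer hide a distribution problem.
for train_idx, val_idx in StratifiedKFold(n_splits=5, shuffle=True,
                                          random_state=7).split(paths, labels):
    train_fold, val_fold = paths[train_idx], paths[val_idx]
    # build and train a fresh classifier on train_fold, validate on val_fold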

Hope this helps.