Keras: model overfitting?

时间:2020-01-27 12:45:53

标签: python machine-learning keras deep-learning classification

I'm trying to build a binary image classification model using the malaria dataset from the NIH (National Library of Medicine), which contains roughly 27,000 images per class (parasitized/uninfected).

There seems to be an overfitting problem. I've tried different batch sizes, different steps per epoch / validation steps, different hidden layers, adding callbacks, and so on. The plot always shows a straight line that either rises or falls sharply, rather than climbing steadily toward a plateau as learning progresses (which, as I understand it, is what should happen). Below is an example; most of my results look similar.

Example plot

I'm new to deep learning and have read a lot about overfitting and possible solutions, but I think I must be doing something wrong and/or misunderstanding something. If someone could spot anything that looks off and point me in the right direction, I'd really appreciate it!

from keras.layers import MaxPooling2D, Conv2D, Flatten, Dense, Dropout
from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from keras_preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from keras.models import Sequential
import matplotlib.pyplot as plt
import constants as c
import numpy as np
import keras

# Clear session and instantiate model
keras.backend.clear_session()
model = Sequential()

# Load images & labels
cells = np.load(c.cells_path)
labels = np.load(c.labels_path)

# Shuffle the entire dataset
n = np.arange(cells.shape[0])
np.random.shuffle(n)

# Update numpy files with shuffled data
cells = cells[n]
labels = labels[n]

# Split the dataset into train/validation/test
train_x, test_x, train_y, test_y = train_test_split(cells, labels, test_size=1 - c.train_ratio, shuffle=False)
val_x, test_x, val_y, test_y = train_test_split(test_x, test_y, test_size=c.test_ratio / (c.test_ratio + c.val_ratio),
                                                shuffle=False)

# The number of images in each set
print('Training data shape: ', train_x.shape)
print('Validation data shape: ', val_x.shape)
print('Testing data shape: ', test_x.shape)

# Neural network
model.add(Conv2D(32, (3, 3), input_shape=c.input_shape, activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(units=64, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(units=1, activation='sigmoid'))

# Compile the model
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Data augmentation
train_datagen = ImageDataGenerator(rescale=1. / 255,
                                   rotation_range=20,
                                   width_shift_range=0.05,
                                   height_shift_range=0.05,
                                   shear_range=0.05,
                                   zoom_range=0.05,
                                   horizontal_flip=True,
                                   fill_mode='nearest')

validation_datagen = ImageDataGenerator(rescale=1. / 255)
testing_datagen = ImageDataGenerator(rescale=1. / 255)

training_dataset = train_datagen.flow(train_x, train_y, batch_size=32)
validation_dataset = validation_datagen.flow(val_x, val_y, batch_size=32)
testing_dataset = testing_datagen.flow(test_x, test_y, batch_size=32)

# Add callbacks to prevent overfitting
es = EarlyStopping(monitor='accuracy',
                   min_delta=0,
                   patience=2,
                   verbose=0,
                   mode='max')

rlrop = ReduceLROnPlateau(monitor='val_loss',
                          factor=0.2,
                          patience=2,
                          min_lr=0.001)

checkpoint = ModelCheckpoint("Model.h5")

# Perform backpropagation and update weights in model
history = model.fit_generator(training_dataset,
                              epochs=50,
                              validation_data=validation_dataset,
                              callbacks=[es, checkpoint, rlrop])

# Save model & weights
model.save_weights("Model_weights.h5")
model.save("Model.h5")

# Plot accuracy graph
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()

1 Answer

Answer 0 (score: 1)

It doesn't look like overfitting. Without digging much deeper, you could try the following:

  1. Keep the filters at 32 in the first layer, then progressively double them in each subsequent convolutional layer.

  2. Lower the dropout rate, since the variation in the images is not significant.

Oddly enough, this is what I built the first time I tried TensorFlow 2.0; you can check it out here.
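Putting the answer's two suggestions together, a minimal sketch of the revised architecture might look like the following. The 0.3 dropout rate and the default input shape are illustrative assumptions, not values given in the answer:

```python
from keras.layers import MaxPooling2D, Conv2D, Flatten, Dense, Dropout
from keras.models import Sequential


def build_model(input_shape=(64, 64, 3)):
    """Revised CNN: filters double per conv block, dropout lowered."""
    model = Sequential()

    # Suggestion 1: start at 32 filters and double each block (32 -> 64 -> 128)
    model.add(Conv2D(32, (3, 3), activation='relu', input_shape=input_shape))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(128, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Flatten())
    model.add(Dense(units=64, activation='relu'))

    # Suggestion 2: lower dropout (0.3 here is an assumed value, down from 0.5)
    model.add(Dropout(0.3))
    model.add(Dense(units=1, activation='sigmoid'))

    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model
```

The intuition behind the doubling pattern is that spatial resolution shrinks with each pooling step, so later layers can afford more filters to capture increasingly abstract features.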