我目前正在使用Tensorflow Keras开发用于多分类(3类)的CNN。我曾使用sklearn将数据拆分为9:1训练/验证(1899训练数据,212验证数据)。
我的CNN模型在17个时期后开始缓慢增加。这是否意味着CNN模型开始过度拟合?关于减少验证损失的任何建议,因为我在CNN模型中使用了辍学和批量标准化。我还使用了EarlyStopping来研究我的CNN模型,但是经过一些调整后,我的CNN模型仍然面临着这个问题。
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)
model = Sequential()
# filters, kernel size, input size
model.add(Conv2D(256, (3, 3), input_shape=X.shape[1:], padding='same'))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=2))
model.add(Dropout(0.2))
model.add(Conv2D(256, (3, 3), padding='same'))
model.add(Dropout(0.2))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=2))
model.add(Conv2D(256, (3, 3), padding='same'))
model.add(Dropout(0.25))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=2))
model.add(Conv2D(256, (3, 3), padding='same'))
model.add(Dropout(0.25))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=2))
model.add(Flatten())
model.add(Dense(256))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dropout(0.8))
model.add(Dense(3))
model.add(Activation('softmax'))
tensorboard = TensorBoard(log_dir="CNN_Model_Rebuilt/logs/{}".format(NAME))
augmented_checkpoint = ModelCheckpoint(
'CNN_Model_Rebuilt/best model/normalization-best.h5',
monitor='val_loss', verbose=0,
save_best_only=True, mode='auto')
es = EarlyStopping(monitor='val_loss',
min_delta=0,
patience=20,
verbose=0, mode='auto')
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train, to_categorical(y_train), batch_size=32, epochs=100,
validation_data=(X_test, to_categorical(y_test)),
callbacks=[augmented_checkpoint, tensorboard, es], verbose=2)
任何建议/建议将不胜感激。谢谢。
答案 0 :(得分:0)
要减少过度拟合,您可以尝试增加输入数据量,然后增强它(翻转,旋转,缩放等),这可以提高泛化能力。
对于模型,您可以尝试的一些想法是增加层数,增加单位数,增加退出值< / strong>,使用不同的激活,添加未使用的新层,以及其他一些内置和自定义的正则化方法。< / p>
您始终可以围绕这些参数和图层进行操作,以获得更准确的模型或有时更差的。这个想法是调整超参数是混合和匹配。 您还可以在此link中阅读有关称为 Tensor Flow Model Optimization Toolkit 的TensorFlow库扩展的信息。