如何避免CNN过度拟合?

时间:2020-06-01 16:16:09

标签: tensorflow keras deep-learning cnn vgg-net

我正在建立一个模型,通过分析他们的面孔来预测人们的年龄。我正在使用this预训练模型,并制作了自定义损失函数和自定义指标。因此,我获得了离散结果,但我想对其进行改进。特别是,我注意到在某些时期之后,模型开始过度拟合训练集,然后val_loss增大。如何避免这种情况?我已经在使用Dropout,但这似乎还不够。 我想也许我应该使用l1和l2,但我不知道如何。

def resnet_model():
  model = VGGFace(model = 'resnet50')#model :{resnet50, vgg16, senet50}
  xl = model.get_layer('avg_pool').output
  x = keras.layers.Flatten(name='flatten')(xl)
  x = keras.layers.Dense(4096, activation='relu')(x)
  x = keras.layers.Dropout(0.5)(x)
  x = keras.layers.Dense(4096, activation='relu')(x)
  x = keras.layers.Dropout(0.5)(x)
  x = keras.layers.Dense(11, activation='softmax', name='predictions')(x)
  model = keras.engine.Model(model.input, outputs = x)
  return model

model = resnet_model()
initial_learning_rate = 0.0003

epochs = 20; batch_size = 110
num_steps = train_x.shape[0]//batch_size
learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    [3*num_steps, 10*num_steps, 16*num_steps, 25*num_steps],
    [1e-4, 1e-5, 1e-6, 1e-7, 5e-7]
    )

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate_fn)
model.compile(loss=custom_loss, optimizer=optimizer, metrics=['accuracy', one_off_accuracy])

model.fit(train_x, train_y, epochs=epochs, batch_size=batch_size, validation_data=(test_x, test_y))

这是结果示例:

enter image description here enter image description here enter image description here

2 个答案:

答案 0 :(得分:0)

您可以尝试在训练中加入图像增强,这会增加数据的“样本量”以及@Suraj S Jain 提到的“多样性”。官方教程在这里:https://www.tensorflow.org/tutorials/images/data_augmentation

答案 1 :(得分:0)

有许多正则化方法可帮助您避免过度拟合模型:

辍学: 在训练过程中随机禁用神经元,以强制其他神经元也被训练。

L1/L2 处罚: 惩罚剧烈变化的权重。这试图确保在对输入进行分类时将同等考虑所有参数。

输入端的随机高斯噪声: 在输入处添加随机高斯噪声:x = x + r 其中 r 是范围 [-1, 1] 中的随机正态值。这会混淆您的模型并防止其过度拟合到您的数据集,因为在每个时期,每个输入都会不同。

标签平滑: 您可以平滑这些值(例如 0.1 和 0.9),而不是说目标是 0 或 1。

提前停止: 这是避免过度训练模型的一种非常常见的技术。如果您注意到模型的损失随着验证的准确性而下降,那么这是停止训练的好兆头,因为您的模型开始过度拟合。

K 折交叉验证: 这是一种非常强大的技术,可确保您的模型不会一直使用相同的输入,也不会过度拟合。

数据增强: 通过旋转/移动/缩放/翻转/填充等图像,您可以确保您的模型被迫更好地训练其参数,而不是过度拟合现有数据集。

我很确定还有更多技术可以避免过度拟合。该存储库包含许多示例,说明如何在数据集中部署上述技术: https://github.com/kochlisGit/Tensorflow-State-of-the-Art-Neural-Networks