Keras提前停止

时间:2017-05-11 03:30:21

标签: python keras conv-neural-network

我使用Keras为我的项目训练神经网络。 Keras提供了早期停止的功能。我是否应该知道应该观察哪些参数以避免我的神经网络通过使用早期停止而过度拟合?

2 个答案:

答案 0 :(得分:112)

early stopping

一旦损失开始增加(或者换句话说,验证准确度开始下降),提前停止基本上会停止训练。根据{{​​3}},它使用如下;

keras.callbacks.EarlyStopping(monitor='val_loss',
                              min_delta=0,
                              patience=0,
                              verbose=0, mode='auto')

值取决于您的实施(问题,批量大小等......)但通常是为了防止过度拟合我会使用;

  1. 监控验证丢失(需要使用交叉 验证或至少训练/测试集)通过设置monitor 'val_loss'的论据。
  2. min_delta是量化某个时代的损失的门槛 改进与否。如果损失的差异低于min_delta,则将其量化 没有改善。最好将它保留为0,因为我们对它感兴趣 当损失变得更糟时。
  3. patience参数表示一旦损失开始增加(停止改善),停止前的纪元数。 如果您使用非常小的批次,这取决于您的实施 或大学习率你的损失 zig-zag (准确度会更嘈杂)所以更好地设置一个 大patience个参数。如果您使用大批量小 学习率你的损失会更顺畅,所以你可以使用 较小的patience参数。无论哪种方式,我都会把它留给2,所以我愿意 给模特更多机会。
  4. verbose决定要打印的内容,保留默认值(0)。
  5. mode参数取决于您监控数量的方向 有(因为它应该减少或增加),因为我们监测损失,我们可以使用min。但是,让我们离开keras 处理我们并将其设置为auto
  6. 所以我会使用这样的东西,并通过绘制错误丢失进行实验,无论是否提前停止。

    keras.callbacks.EarlyStopping(monitor='val_loss',
                                  min_delta=0,
                                  patience=2,
                                  verbose=0, mode='auto')
    

    关于回调如何运作的可能模糊性,我将尝试解释更多。在您的模型上调用fit(... callbacks=[es])后,Keras会调用给定的回调对象预定函数。这些函数可以称为on_train_beginon_train_endon_epoch_beginon_epoch_endon_batch_beginon_batch_end。在每个纪元结束时调用早期停止回调,将最佳监测值与当前监测值进行比较,并在条件满足时停止(自观察到最佳监测值以来已经过去多少个时期并且不仅仅是耐心论证,它们之间的区别是最后一个值大于min_delta等。)。

    正如@BrentFaust在评论中指出的那样,模型的训练将持续到满足早期停止条件或epochs中的fit()参数(默认值= 10)为止。设置Early Stopping回调不会使模型训练超出其epochs参数。因此,使用较大的fit()值调用epochs函数可以从早期停止回调中获益更多。

答案 1 :(得分:1)

以下是另一个项目 AutoKeras (https://autokeras.com/) 中的 EarlyStopping 示例,这是一个自动化机器学习 (AutoML) 库。该库设置了两个 EarlyStopping 参数:patience=10min_delta=1e-4

https://github.com/keras-team/autokeras/blob/5e233956f32fddcf7a6f72a164048767a0021b9a/autokeras/engine/tuner.py#L170

AutoKeras 和 Keras 的默认监控数量是 val_loss

https://github.com/keras-team/keras/blob/cb306b4cc446675271e5b15b4a7197efd3b60c34/keras/callbacks.py#L1748 https://autokeras.com/image_classifier/