如何在keras中修改ModelCheckPoint以监视val_acc和val_loss并相应地保存最佳模型?

时间:2018-12-12 06:02:26

标签: tensorflow keras

ModelCheckPoint提供了分别保存val_Accval_loss的选项。 我想以某种方式进行修改,以使val_acc有所改进->保存模型。如果val_acc等于先前的最佳val_acc,则检查val_loss,如果val_loss小于先前的最佳val_loss,则保存模型。

    if val_acc(epoch i)> best_val_acc:
        save model
    else if val_acc(epoch i) == best_val_acc:
        if val_loss(epoch i) < best_val_loss:
           save model
        else
           do not save model

3 个答案:

答案 0 :(得分:5)

您可以仅添加两个回调:

callbacks = [ModelCheckpoint(filepathAcc, monitor='val_acc', ...),
             ModelCheckpoint(filepathLoss, monitor='val_loss', ...)]

model.fit(......., callbacks=callbacks)

使用自定义回调

您可以在LambdaCallback(on_epoch_end=saveModel)中做任何您想做的事情。

best_val_acc = 0
best_val_loss = sys.float_info.max 

def saveModel(epoch,logs):
    val_acc = logs['val_acc']
    val_loss = logs['val_loss']

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        model.save(...)
    elif val_acc == best_val_acc:
        if val_loss < best_val_loss:
            best_val_loss=val_loss
            model.save(...)

callbacks = [LambdaCallback(on_epoch_end=saveModel)]

但这与单个ModelCheckpointval_acc没什么不同。除非您使用的样本很少,或者自定义精度相差不大,否则您将不会真正获得相同的准确性。

答案 1 :(得分:0)

您实际上可以检查他们的文档!

为了节省您一些时间,回调ModelCheckpoint接受一个名为save_best_only的参数,它可以完成您想要的事情,只需将其设置为True。这是文档的链接

我误解了你的问题。我想,如果您想使用更复杂的回调类型,则可以始终使用基本的Callback函数,因为您可以同时访问parmasmodel,所以该函数可以提供更多功能。签出docu。您可以先进行测试并打印参数,然后确定要记录的参数。

答案 2 :(得分:0)

here中签出ModelCheckPoint。 model.fit()方法将回调列表作为参数。确保您有类似的东西:

model.fit(..., callbacks=[mcp] ),其中mcp = ModelCheckPoint()已定义。

注意:您可能在回调列表中有多个回调。

为清楚起见,我添加了一些细节,但实际上,它与model.save()函数的作用相同:

class ModelCheckpoint(Callback):
    """Save the model after every epoch.
    `filepath` can contain named formatting options,
    which will be filled the value of `epoch` and
    keys in `logs` (passed in `on_epoch_end`).
    For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
    then the model checkpoints will be saved with the epoch number and
    the validation loss in the filename.
    # Arguments
        filepath: string, path to save the model file.
        monitor: quantity to monitor.
        verbose: verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`,
            the latest best model according to
            the quantity monitored will not be overwritten.
        mode: one of {auto, min, max}.
            If `save_best_only=True`, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        save_weights_only: if True, then only the model's weights will be
            saved (`model.save_weights(filepath)`), else the full model
            is saved (`model.save(filepath)`).
        period: Interval (number of epochs) between checkpoints.
    """