如何在scikit-learn的cross_val_score()中每次折叠后运行函数?

时间:2017-10-12 12:13:53

标签: callback scikit-learn cross-validation

我想用scikit-learn cross_val_score()函数对我的Keras神经网络进行交叉验证。

问题是,每次折叠后不仅会记住结果,还会影响整个Keras模型。因此,我想在每次折叠后使用K.clear_session()清除此模型。但这只是背景的细节。

我的主要问题是:如何使用scikit-learn中的cross_val_score()在每次折叠后运行自定义函数?换句话说:可以运行应该在每次折叠后运行的回调?还是存在其他解决方法?

1 个答案:

答案 0 :(得分:0)

您可以创建自定义回调并重新编写此回调的on_train_end(self,logs = {})方法。这种新方法将在每个训练步骤结束时完成。这样的事情:

class CustomCall(Callback):

    def __init__(self):
        super(CustomCall, self).__init__()

    def on_epoch_begin(self, epoch, logs={}):
        return

    def on_epoch_end(self, epoch, logs={}):
        return

    def on_batch_begin(self, batch, logs={}):
        return

    def on_train_end(self, logs={}):
        # Stuff here
        print('\n Delete previous trained model : ')
        K.clear_session()
        return