我想用scikit-learn cross_val_score()
函数对我的Keras神经网络进行交叉验证。
问题是,每次折叠后不仅会记住结果,还会影响整个Keras模型。因此,我想在每次折叠后使用K.clear_session()
清除此模型。但这只是背景的细节。
我的主要问题是:如何使用scikit-learn中的cross_val_score()在每次折叠后运行自定义函数?换句话说:可以运行应该在每次折叠后运行的回调?还是存在其他解决方法?
答案 0 :(得分:0)
您可以创建自定义回调并重新编写此回调的on_train_end(self,logs = {})方法。这种新方法将在每个训练步骤结束时完成。这样的事情:
class CustomCall(Callback):
def __init__(self):
super(CustomCall, self).__init__()
def on_epoch_begin(self, epoch, logs={}):
return
def on_epoch_end(self, epoch, logs={}):
return
def on_batch_begin(self, batch, logs={}):
return
def on_train_end(self, logs={}):
# Stuff here
print('\n Delete previous trained model : ')
K.clear_session()
return