我正在使用GridSearchCV
中的scikit-learn
在keras
的神经网络中进行网格搜索。我想自定义callback
,以便每次在一个网格点上完成网络训练后,就可以打印出适合的结果。
假设我按如下方式定义网格:
param_grid = dict(epochs=[50, 100, 500, 1000],
learn_rate=[0.1, 0.2, 0.3],
momentum=[0.01, 0.1],
dropout_rate=[0.05, 0.1, 0.15, 0.2])
我计算网格上的可能性总数为:
grid_size = reduce(lambda x,y: x*y,[len(param_grid_[key]) for key in param_grid])
回调是:
from keras.callbacks import ModelCheckpoint, EarlyStopping
# checkpoint
filepath="best_model.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1,
save_best_only=True, mode='max')
# Early stoping
monitor = EarlyStopping(monitor='val_loss', min_delta=1e-5, patience=200,
verbose=1, mode='auto')
callbacks_list = [checkpoint, monitor, LiveGridReport()]
其中LiveGridReport()
是我的自定义回调,它显示有关在网格点上完成训练的消息。
class LiveGridReport(keras.callbacks.Callback):
def __init__(self, grid_size):
grid_size_ = grid_size
def on_train_begin(self, logs={}):
return
def on_train_end(self, logs={}):
return
我的问题是,考虑到我也有EarlyStopping
回调,我无法弄清楚如何检测到网格点上的训练已终止。
答案 0 :(得分:1)
EarlyStopping
回调时停止了哪个时代训练
EarlyStopping.stopped_epoch
或使用历史记录
history = model.fit(....)
number_of_epochs_it_ran = len(history.history['loss'])