如何使用GridSearch保存具有最佳参数的Keras模型

时间:2018-02-22 08:34:25

标签: python machine-learning neural-network keras

有什么方法可以保存使用Gridsearch获得的最佳参数的完整Keras模型。

我有以下Keras型号:

def create_model(init_mode='uniform'):
    n_x_new=train_selected_x.shape[1]

    model = Sequential()
    model.add(Dense(n_x_new, input_dim=n_x_new, kernel_initializer=init_mode, activation='sigmoid'))
    model.add(Dense(10, kernel_initializer=init_mode, activation='sigmoid'))
    model.add(Dropout(0.8))

    model.add(Dense(1, kernel_initializer=init_mode, activation='sigmoid'))


    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

    return model

seed = 7
np.random.seed(seed)


model = KerasClassifier(build_fn=create_model, epochs=30, batch_size=400, verbose=1)

init_mode = ['uniform', 'lecun_uniform', 'normal', 'zero', 'glorot_normal', 'glorot_uniform', 'he_normal', 'he_uniform']
param_grid = dict(init_mode=init_mode)
#cv = PredefinedSplit(test_fold=my_test_fold)
grid = GridSearchCV(estimator=model, param_grid=param_grid,scoring='roc_auc',cv = PredefinedSplit(test_fold=my_test_fold), n_jobs=1)
grid_result = grid.fit(np.concatenate((train_selected_x, test_selected_x), axis=0), np.concatenate((train_selected_y, test_selected_y), axis=0))



print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
    print("%f (%f) with: %r" % (mean, stdev, param))

我发现我可以使用callbackcheckpoint方法,但我不知道在原始代码中将此方法所需的代码放在何处。

我在研究时遇到的代码如下。

filepath="weights.best.hdf5"
    checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
    callbacks_list = [checkpoint]

1 个答案:

答案 0 :(得分:0)

环顾四周后,它似乎必须像

一样简单
classifier = KerasClassifier(build_fn=DNN, nb_epoch=32, batch_size=8, callbacks=[your_callback], verbose=1)

但这似乎也不起作用。 可能的解决方法来自给出的答案 - Can I send callbacks to a KerasClassifier?,这应该有所帮助。

  

这是使用不同的通用工具的结果   并没有特别想到与所有可能的一起使用   配置。

此外,您可以参考此票证 - How to pass callbacks to scikit_learn wrappers (e.g. KerasClassifier) #4278

希望它有所帮助!