为什么第二次“适合”调用`GridSearchCV`无休止地工作?

时间:2018-05-29 11:55:00

标签: python machine-learning scikit-learn keras jupyter-notebook

我正使用Keras中的GridSearchCV调整sklearn模型中的超参数,如this tutorial

model = KerasClassifier(build_fn=create_model, verbose=0)
batch_size = [10, 20, 40, 60, 80, 100]
epochs = [10, 50, 100]
param_grid = dict(batch_size=batch_size, epochs=epochs)
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)

他们调用fit方法找到最佳超参数

grid_result = grid.fit(X, Y)

但是,让我们说我想要更改batch_sizes并再次调用fit(无需在Jupyter中重新启动内核)。

batch_size = [15, 20, 25, 30, 35, 40]
param_grid = dict(batch_size=batch_size, epochs=epochs)
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)

当我致电fit

grid_result = grid.fit(X, Y)

它无休止地工作并且不会终止。为了适应这些已更改的参数,我必须重新启动内核,然后重新加载数据,模块等。

问题。如何在不重新启动内核的情况下第二次在fit上调用GridSearchCV

详细信息。我使用this data。详细摘录:

import numpy as np
from sklearn.model_selection import GridSearchCV
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier

def create_model():
    model = Sequential()
    model.add(Dense(12, input_dim=8, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

dataset = np.loadtxt("data/pima-indians-diabetes.data.csv", delimiter=",")
X = dataset[:,0:8]
y = dataset[:,8]

model = KerasClassifier(build_fn=create_model, verbose=0)

batch_size = [10, 20, 40, 60, 80, 100]
epochs = [10, 50]
param_grid = dict(batch_size = batch_size, epochs = epochs)

grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)

然后我调用fit并且它可以正常工作

grid_result = grid.fit(X, y)
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))

出:Best: 0.690104 using {'batch_size': 10, 'epochs': 50}

然后我运行以下内容进行一些更改:

batch_size = [5, 10, 15, 20]
param_grid = dict(batch_size = batch_size, epochs = epochs)
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)

最后我第二次打电话给fit,永远不会停止

grid_result = grid.fit(X, y)

1 个答案:

答案 0 :(得分:2)

我可以在MacBook Pro上重现错误。 我也使用了pima-indians-diabetes.data.csv数据集。

这里的问题是tensorflow会话。如果在GridSearchCV.fit()之前在父进程中创建了一个会话,它肯定会挂起。

一种可能的解决方案是将所有会话创建代码限制为KerasClassifer类和模型创建函数。

此外,您可能希望在模型创建函数或KerasClassifier的子类中限制TF的内存使用。

快速解决方案:

n_jobs = 1

但需要很长时间才能完成。

<强>参考文献:

Session hang issue with python multiprocessing

Keras + Tensorflow and Multiprocessing in Python

Limit the resource usage for tensorflow backend

GridSearchCV Hangs On Second Run

scikit-lean GridSearchCV n_jobs != 1 freezing