我正使用Keras
中的GridSearchCV
调整sklearn
模型中的超参数,如this tutorial
model = KerasClassifier(build_fn=create_model, verbose=0)
batch_size = [10, 20, 40, 60, 80, 100]
epochs = [10, 50, 100]
param_grid = dict(batch_size=batch_size, epochs=epochs)
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)
他们调用fit
方法找到最佳超参数
grid_result = grid.fit(X, Y)
但是,让我们说我想要更改batch_sizes
并再次调用fit
(无需在Jupyter中重新启动内核)。
batch_size = [15, 20, 25, 30, 35, 40]
param_grid = dict(batch_size=batch_size, epochs=epochs)
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)
当我致电fit
grid_result = grid.fit(X, Y)
它无休止地工作并且不会终止。为了适应这些已更改的参数,我必须重新启动内核,然后重新加载数据,模块等。
问题。如何在不重新启动内核的情况下第二次在fit
上调用GridSearchCV
?
详细信息。我使用this data。详细摘录:
import numpy as np
from sklearn.model_selection import GridSearchCV
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
def create_model():
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
dataset = np.loadtxt("data/pima-indians-diabetes.data.csv", delimiter=",")
X = dataset[:,0:8]
y = dataset[:,8]
model = KerasClassifier(build_fn=create_model, verbose=0)
batch_size = [10, 20, 40, 60, 80, 100]
epochs = [10, 50]
param_grid = dict(batch_size = batch_size, epochs = epochs)
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)
然后我调用fit
并且它可以正常工作
grid_result = grid.fit(X, y)
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
出:Best: 0.690104 using {'batch_size': 10, 'epochs': 50}
然后我运行以下内容进行一些更改:
batch_size = [5, 10, 15, 20]
param_grid = dict(batch_size = batch_size, epochs = epochs)
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)
最后我第二次打电话给fit
,永远不会停止
grid_result = grid.fit(X, y)
答案 0 :(得分:2)
我可以在MacBook Pro上重现错误。
我也使用了pima-indians-diabetes.data.csv
数据集。
这里的问题是tensorflow会话。如果在GridSearchCV.fit()
之前在父进程中创建了一个会话,它肯定会挂起。
一种可能的解决方案是将所有会话创建代码限制为KerasClassifer
类和模型创建函数。
此外,您可能希望在模型创建函数或KerasClassifier
的子类中限制TF的内存使用。
快速解决方案:
n_jobs = 1
但需要很长时间才能完成。
<强>参考文献:强>
Session hang issue with python multiprocessing
Keras + Tensorflow and Multiprocessing in Python
Limit the resource usage for tensorflow backend