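I am trying to wrap a Keras model for scikit-learn so that I can use grid search and a pipeline for hyperparameter tuning.
It works fine when the build_fn function passed to KerasClassifier takes no arguments. However, as soon as I use a function that takes arguments, it fails.
Sample code is given below: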
    # Imports assumed for the Keras 2.x scikit-learn wrapper; they were not shown in the original post.
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.wrappers.scikit_learn import KerasClassifier
    from sklearn.pipeline import Pipeline
    from sklearn.model_selection import GridSearchCV

    def prepare_classifier(x, y):
        shape_of_input = x.shape
        shape_of_target = y.shape
        classifier = Sequential()
        # units = 30 is the number of neurons in the layer
        # kernel_initializer determines how the weights are initialized
        # activation is the activation function of this hidden layer
        # input_dim is the number of features in a single row, here shape_of_input[1]
        # shape_of_input[0] is the total number of such rows
        classifier.add(Dense(units = 30, activation = 'relu', kernel_initializer = 'uniform', input_dim = shape_of_input[1]))
        classifier.add(Dense(units = 30, activation = 'relu', kernel_initializer = 'uniform'))
        # we predict one of 10 digit classes for each of the shape_of_input[0] rows of x
        classifier.add(Dense(10, activation = 'softmax'))
        # categorical_crossentropy is the loss function for multi-class classification
        classifier.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])
        return classifier

    def fit(classifier, x_train, y_train, epoch_size, batch_size = 10):
        pipeline = Pipeline([
            ('keras_classifier', classifier)
        ])
        param_grid = {
            'keras_classifier__batch_size' : [10, 20, 30, 50],
            'keras_classifier__epochs' : [100, 200, 300],
            'keras_classifier__x' : [x_train],
            'keras_classifier__y' : [y_train],
        }
        grid = GridSearchCV(estimator = pipeline, param_grid = param_grid, n_jobs = -1)
        grid.fit(x_train, y_train)
        print("Best parameters are : ", grid.best_params_, '\n grid best score :', grid.best_score_)

    classifier = KerasClassifier(build_fn = prepare_classifier, x = x_train[0:100], y = y_train )
    fit(classifier, x_train[:100], y_train, epoch_size )
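This is for some x and y data (P.S. I used the MNIST data).
The error I get is:
RuntimeError: Cannot clone object, as the constructor either does not set or modifies parameter x
However, if my prepare_classifier function takes no arguments, the code works fine.
What am I doing wrong?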
Answer 0 (score: 0)
Solved. Essentially, the following line was the problem:
    classifier = KerasClassifier(build_fn = prepare_classifier, x = x_train[0:100], y = y_train )
It needs to be changed to
    classifier = KerasClassifier(build_fn = prepare_classifier)
and the arguments of prepare_classifier need to be passed through param_grid.
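
For anyone wondering why the constructor version fails, here is a minimal standard-library sketch of what I believe is going on. It is a simplified model, not scikit-learn's or Keras's actual code: WrapperSketch and clone_sketch are hypothetical names, and the sketch assumes the wrapper's get_params returns a deep copy of the stored keyword arguments while cloning checks parameters by identity. Under those assumptions, array-like values given to the constructor fail the check, whereas values set afterwards (as GridSearchCV does with param_grid) are never compared.

```python
import copy


class WrapperSketch:
    """Hypothetical stand-in for a wrapper that stores build_fn kwargs.

    Assumption: get_params returns a deep copy of the stored keyword
    arguments, so mutable values come back as new objects each time.
    """

    def __init__(self, build_fn=None, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def get_params(self):
        params = copy.deepcopy(self.sk_params)
        params["build_fn"] = self.build_fn
        return params

    def set_params(self, **params):
        self.sk_params.update(params)
        return self


def clone_sketch(estimator):
    """Simplified model of the identity check performed when cloning."""
    params = estimator.get_params()
    new_estimator = type(estimator)(**params)
    new_params = new_estimator.get_params()
    for name, value in params.items():
        # A parameter must come back as the very same object it went in as.
        if new_params[name] is not value:
            raise RuntimeError(
                "Cannot clone object, as the constructor either does not "
                "set or modifies parameter %s" % name
            )
    return new_estimator


def build(x=None, input_dim=None):
    """Placeholder build function; a real one would return a compiled model."""
    return None


rows = [[0.0, 1.0], [1.0, 0.0]]  # stands in for x_train

# Array-like data passed to the constructor: the deep copies differ, so cloning fails.
try:
    clone_sketch(WrapperSketch(build_fn=build, x=rows))
    data_in_constructor_ok = True
except RuntimeError:
    data_in_constructor_ok = False

# Bare wrapper: cloning succeeds, and the data is attached afterwards.
bare = clone_sketch(WrapperSketch(build_fn=build))
bare.set_params(x=rows)

# An immutable scalar survives deepcopy as the same object, so it clones fine.
scalar = clone_sketch(WrapperSketch(build_fn=build, input_dim=2))
```

If this model is right, another option is to give prepare_classifier a plain integer such as the number of input features instead of the arrays themselves, since small immutable values do not trip the identity check.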