KerasClassifier用于带参数的build_fn

时间:2018-12-04 06:47:52

标签: python-3.x scikit-learn keras deep-learning

我正在尝试将keras模型包装到scikit中,以学习网格搜索和管道结构以进行超参数调整。

当build_fn函数采用0个参数在KerasClassifier中使用时,它绝对可以正常工作。但是,只要我使用带有参数的函数,它就会失败

下面给出的示例代码

def prepare_classifier(x, y):

    shape_of_input = x.shape
    shape_of_target = y.shape

    classifier  = Sequential()

    ## number of neurons = 30
    ## kernel_initializer determines how the weights are initialized
    ## activation is the activation function at this particular hidden layer
    ## input_shape is the number of features in a single row.. in this case it is shape_of_input[1]
    ## shape_of_input[0] is the total number of such rows
    classifier.add(Dense(units = 30, activation = 'relu', kernel_initializer = 'uniform', input_dim = shape_of_input[1]))

    classifier.add(Dense(units = 30, activation = 'relu', kernel_initializer = 'uniform'))

    ## we are predicting 10 digits for each row of x.
    ## in total there are shape_of_input[0] rows in total
    classifier.add(Dense(10, activation = 'softmax'))

    ## categorical_crossentropy is the loss function for multi output loss function
    classifier.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])

    return classifier



def fit(classifier, x_train, y_train, epoch_size, batch_size = 10):


    pipeline = Pipeline([
                ('keras_classifier', classifier)
        ])

    param_grid = {

        'keras_classifier__batch_size' : [10,20,30,50],
        'keras_classifier__epochs' : [100, 200, 300],
        'keras_classifier__x' : [x_train],
        'keras_classifier__y' : [y_train],

    }


    grid = GridSearchCV(estimator = pipeline, param_grid = param_grid, n_jobs = -1)
    grid.fit(x_train, y_train)

    print("Best parameters are : ", grid.best_params_, '\n grid best score :', grid.best_score_)



classifier =  KerasClassifier(build_fn = prepare_classifier, x = x_train[0:100], y = y_train )

fit(classifier, x_train[:100], y_train, epoch_size )

这是针对某些x和y数据的(p.s.我使用了mnist数据)

我得到的错误是:

RuntimeError:无法克隆object,因为构造函数未设置或修改参数x

但是,如果我的prepare_classifier函数不带任何参数,代码绝对可以正常工作。

我做错了什么?

1 个答案:

答案 0 :(得分:0)

解决了。本质上,下一行是问题

classifier =  KerasClassifier(build_fn = prepare_classifier, x = x_train[0:100], y = y_train )

需要更改为

classifier =  KerasClassifier(build_fn = prepare_classifier)

和prepare_classifier的参数需要使用param_grid发送