scikit learn:与GridSearchCV兼容的自定义分类器

时间:2018-01-11 16:19:06

标签: python machine-learning scikit-learn

我已经实现了自己的分类器,现在我想对它进行网格搜索,但是我收到了以下错误:estimator.fit(X_train, y_train, **fit_params) TypeError: fit() takes 2 positional arguments but 3 were given

我关注this tutorial并使用this template提供的scikit's official documentation。我的课程定义如下:

class MyClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, lr=0.1):
        self.lr=lr

    def fit(self, X, y):
        # Some code
        return self
    def predict(self, X):
        # Some code
        return y_pred
    def get_params(self, deep=True)
        return {'lr'=self.lr}
    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

我正在尝试网格搜索,如下所示:

params = {
    'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)

编辑我

这就是我调用它的方式:     gs.fit([' hello world','尝试',' hello world','尝试',' hello world&# 39;,'尝试',' hello world','尝试'],            ['我' Z','我'' Z','我',' Z','我',' Z'])

END编辑我

错误由文件_fit_and_score

中的python3.5/site-packages/sklearn/model_selection/_validation.py方法产生

它用3个参数调用estimator.fit(X_train, y_train, **fit_params),但我的估算器只有两个,所以错误对我有意义,但我不知道如何解决它...我也尝试添加一些<{1}}方法的虚假参数,但它没有用。

编辑II

完成错误输出:

fit

END EDIT II

解决 谢谢大家,我有一个愚蠢的错误:有两个不同的功能,同名(适合),(我实现了另一个用于不同参数的自定义目的,只要我重新命名我的&#39;定制适合&#39;,它运作正常。)

谢谢你,抱歉

2 个答案:

答案 0 :(得分:2)

以下代码适用于我:

class MyClassifier(BaseEstimator, ClassifierMixin):
     def __init__(self, lr=0.1):
         # Some code
         pass
     def fit(self, X, y):
         # Some code
         pass
     def predict(self, X):
         # Some code
         return X % 3

params = {
    'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)

x = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
gs.fit(x, y)

我能想到的最好的结果是,您将某些内容传递到gs.fit以及x以外的y方法,或者您的MyClassifier.fit方法缺少自我参数。

只有将kwarg传递给fit_params方法才能填充gs.fit kwargs,否则它是一个空字典({})而**fit_params不会抛出参数错误。要对此进行测试,请创建分类​​器的实例并传递**{}。例如:

clf = MyClassifier()
clf.fit(x, y, **{})

这不会引发位置参数错误。

因此,除非将某些内容传递给gs.fitgs.fit(x, y, some_arg=123)在我看来,你错过了MyClassifier.fit定义中的一个位置参数。您提供的错误消息似乎支持此假设,因为它指出fit() takes 2 positional arguments but 3 were given。如果您按如下方式定义了拟合,则需要3个位置参数:

def fit(self, X, y): ...

答案 1 :(得分:0)

看起来像是一些自定义参数的传递。只需添加一个catch-all关键字参数fit-Method:

transform