我已经实现了自己的分类器,现在我想对它进行网格搜索,但是我收到了以下错误:estimator.fit(X_train, y_train, **fit_params)
TypeError: fit() takes 2 positional arguments but 3 were given
我关注this tutorial并使用this template提供的scikit's official documentation。我的课程定义如下:
class MyClassifier(BaseEstimator, ClassifierMixin):
def __init__(self, lr=0.1):
self.lr=lr
def fit(self, X, y):
# Some code
return self
def predict(self, X):
# Some code
return y_pred
def get_params(self, deep=True)
return {'lr'=self.lr}
def set_params(self, **parameters):
for parameter, value in parameters.items():
setattr(self, parameter, value)
return self
我正在尝试网格搜索,如下所示:
params = {
'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)
编辑我
这就是我调用它的方式: gs.fit([' hello world','尝试',' hello world','尝试',' hello world&# 39;,'尝试',' hello world','尝试'], ['我' Z','我'' Z','我',' Z','我',' Z'])
END编辑我
错误由文件_fit_and_score
python3.5/site-packages/sklearn/model_selection/_validation.py
方法产生
它用3个参数调用estimator.fit(X_train, y_train, **fit_params)
,但我的估算器只有两个,所以错误对我有意义,但我不知道如何解决它...我也尝试添加一些<{1}}方法的虚假参数,但它没有用。
编辑II
完成错误输出:
fit
END EDIT II
解决 谢谢大家,我有一个愚蠢的错误:有两个不同的功能,同名(适合),(我实现了另一个用于不同参数的自定义目的,只要我重新命名我的&#39;定制适合&#39;,它运作正常。)
谢谢你,抱歉
答案 0 :(得分:2)
以下代码适用于我:
class MyClassifier(BaseEstimator, ClassifierMixin):
def __init__(self, lr=0.1):
# Some code
pass
def fit(self, X, y):
# Some code
pass
def predict(self, X):
# Some code
return X % 3
params = {
'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)
x = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
gs.fit(x, y)
我能想到的最好的结果是,您将某些内容传递到gs.fit
以及x
以外的y
方法,或者您的MyClassifier.fit
方法缺少自我参数。
只有将kwarg传递给fit_params
方法才能填充gs.fit
kwargs,否则它是一个空字典({}
)而**fit_params
不会抛出参数错误。要对此进行测试,请创建分类器的实例并传递**{}
。例如:
clf = MyClassifier()
clf.fit(x, y, **{})
这不会引发位置参数错误。
因此,除非将某些内容传递给gs.fit
, gs.fit(x, y, some_arg=123)
在我看来,你错过了MyClassifier.fit
定义中的一个位置参数。您提供的错误消息似乎支持此假设,因为它指出fit() takes 2 positional arguments but 3 were given
。如果您按如下方式定义了拟合,则需要3个位置参数:
def fit(self, X, y): ...
答案 1 :(得分:0)
看起来像是一些自定义参数的传递。只需添加一个catch-all关键字参数fit-Method:
transform