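I am trying to inherit from BaseEstimator and MetaEstimatorMixin to create a wrapper around a base_estimator, but I am running into problems. I tried to follow the base_ensemble code in the repository, but it did not help. When I run the test below, which calls check_estimator(Wrapper), I get TypeError: get_params() missing 1 required positional argument: 'self'. According to the documentation, I do not have to implement get_params if I inherit from BaseEstimator. It seems that something is a class rather than an instance, but I cannot pin down what.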
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, MetaEstimatorMixin, clone
from functools import lru_cache
import numpy as np
from sklearn.linear_model import LogisticRegression

'''
this is a module containing classes which wraps a classifier or a regressor sklearn estimator
'''


class Wrapper(BaseEstimator, MetaEstimatorMixin):
    def __init__(self, base_estimator=LogisticRegression, estimator_params=None):
        super().__init__()
        self.base_estimator = base_estimator
        self.estimator_params = estimator_params

    def fit(self, x, y):
        self.model = self._make_estimator().fit(x, y)

    def _make_estimator(self):
        """Make and configure a copy of the `base_estimator_` attribute.
        Warning: This method should be used to properly instantiate new
        sub-estimators. taken from sklearn github
        """
        estimator = self.base_estimator()
        estimator.set_params(**dict((p, getattr(self, p))
                                    for p in self.estimator_params))
        return estimator

    def predict(self, x):
        self.model.predict(x)
import unittest
from sklearn.utils.estimator_checks import check_estimator


class Test(unittest.TestCase):
    def test_check_estimator(self):
        check_estimator(Wrapper)
Answer
Pass an instance of the estimator, not the class, as the base_estimator field:
....
def __init__(self, base_estimator=LogisticRegression(), ...
....
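A minimal sketch of the whole wrapper following this advice, assuming a recent scikit-learn release. It drops the question's estimator_params handling for brevity, and the fitted attribute name model_ is a choice made here, not something from the original code. It clones base_estimator inside fit so the shared default instance is never modified; it has not been checked against the full check_estimator suite.

```python
import numpy as np
from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone
from sklearn.linear_model import LogisticRegression


class Wrapper(MetaEstimatorMixin, BaseEstimator):
    """Wraps a scikit-learn estimator that is passed as an *instance*."""

    def __init__(self, base_estimator=LogisticRegression()):
        # Store the argument unchanged: get_params() and clone() read it back by name
        self.base_estimator = base_estimator

    def fit(self, X, y):
        # Fit an unfitted copy, so the (shared) default instance is never modified
        self.model_ = clone(self.base_estimator).fit(X, y)
        return self  # scikit-learn convention: fit returns self

    def predict(self, X):
        return self.model_.predict(X)


X = np.array([[-2.0], [-1.0], [1.0], [2.0]])
y = np.array([0, 0, 1, 1])

wrapper = Wrapper().fit(X, y)
print(wrapper.predict(X))  # [0 0 1 1]

# clone() works because get_params() now returns an estimator instance, not a class
wrapper_copy = clone(Wrapper(base_estimator=LogisticRegression(C=0.1)))
print(wrapper_copy.get_params()["base_estimator__C"])  # 0.1
```

Cloning inside fit rather than in __init__ follows the scikit-learn convention that __init__ only stores its arguments.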
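Your error occurs because some of the checks use clone(safe=False). The class LogisticRegression has a get_params attribute, so it is treated like an estimator and get_params is called on the class itself, where there is no instance to bind to self.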
From the clone docstring:

safe : boolean, optional
    If safe is false, clone will fall back to a deepcopy on objects that are not estimators.
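To see why a class default trips this up, here is a heavily simplified, stdlib-only sketch of the recursion clone performs on parameter values. ToyEstimator, ToyWrapper and simplified_clone are hypothetical names made up for illustration; this is not scikit-learn's actual implementation, and recent releases may special-case classes differently.

```python
import copy


class ToyEstimator:
    """Hypothetical stand-in for an estimator such as LogisticRegression."""

    def __init__(self, C=1.0):
        self.C = C

    def get_params(self, deep=True):
        return {"C": self.C}


class ToyWrapper:
    """Hypothetical stand-in for the question's Wrapper."""

    def __init__(self, base_estimator=None):
        self.base_estimator = base_estimator

    def get_params(self, deep=True):
        return {"base_estimator": self.base_estimator}


def simplified_clone(obj, safe=True):
    """Hypothetical, simplified sketch of the recursion in sklearn.base.clone.

    Anything with a get_params attribute is treated as an estimator;
    with safe=False, every other object is deep-copied instead.
    """
    if not hasattr(obj, "get_params"):
        if not safe:
            return copy.deepcopy(obj)
        raise TypeError(f"Cannot clone {obj!r}: it is not an estimator")
    # If obj is a class, obj.get_params is a plain function and `self` stays unfilled
    params = obj.get_params(deep=False)
    # Every parameter value is cloned recursively with safe=False
    new_params = {name: simplified_clone(value, safe=False)
                  for name, value in params.items()}
    return type(obj)(**new_params)


# Instance as the parameter: the recursion clones it like any other estimator
original = ToyWrapper(base_estimator=ToyEstimator(C=0.5))
cloned = simplified_clone(original)
print(cloned.base_estimator.C)  # 0.5

# Class as the parameter: get_params ends up being called on the class itself
error_message = None
try:
    simplified_clone(ToyWrapper(base_estimator=ToyEstimator))
except TypeError as exc:
    error_message = str(exc)
print(error_message)
```

The second call fails with the same "missing 1 required positional argument: 'self'" message as in the question, while the instance version clones cleanly.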