sklearn wrapper containing an estimator raises a get_params "missing self" error

Date: 2015-12-23 19:28:52

Tags: python scikit-learn

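I am trying to inherit from BaseEstimator and MetaEstimatorMixin to create a wrapper around a base_estimator, but I have run into a problem. I tried to follow the base_ensemble code in the repository, but it did not help. When I run the test below, which calls check_estimator(Wrapper), I get TypeError: get_params() missing 1 required positional argument: 'self'. According to the documentation, if I inherit from BaseEstimator I do not have to implement get_params myself. It looks as if something is a class rather than an instance, but I cannot pin down what.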

from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, MetaEstimatorMixin, clone
from functools import lru_cache
import numpy as np
from sklearn.linear_model import LogisticRegression

'''
this is a module containing classes which wraps a classifier or a regressor sklearn estimator
'''


class Wrapper(BaseEstimator, MetaEstimatorMixin):
    def __init__(self, base_estimator=LogisticRegression, estimator_params=None):
        super().__init__()
        self.base_estimator = base_estimator
        self.estimator_params = estimator_params

    def fit(self, x, y):
        self.model = self._make_estimator().fit(x,y)

    def _make_estimator(self):
        """Make and configure a copy of the `base_estimator_` attribute.
        Warning: This method should be used to properly instantiate new
        sub-estimators. taken from sklearn github
        """
        estimator = self.base_estimator()
        estimator.set_params(**dict((p, getattr(self, p))
                                    for p in self.estimator_params))

        return estimator

    def predict(self, x):
        self.model.predict(x)


import unittest
from sklearn.utils.estimator_checks import check_estimator
class Test(unittest.TestCase):
    def test_check_estimator(self):
        check_estimator(Wrapper)

1 answer:

Answer 0 (score: 3)

The base_estimator field must be initialized with an object (an estimator instance), not with a class.

....
def __init__(self, base_estimator=LogisticRegression(), ...
....
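
For illustration, here is a minimal sketch of the wrapper with that change applied. It is one possible arrangement, not the only correct one, and it makes a few assumptions that are not in the question: the default is None and is resolved to LogisticRegression() inside fit, the fitted model is stored under the name model_, the estimator_params argument is dropped because parameters can be set on the instance directly, and the mixin is listed before BaseEstimator as current scikit-learn documentation recommends. It also makes fit return self and predict return its result, which the question's code omits.

```python
import numpy as np
from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone
from sklearn.linear_model import LogisticRegression


class Wrapper(MetaEstimatorMixin, BaseEstimator):
    def __init__(self, base_estimator=None):
        # Store the argument untouched; it should be an estimator
        # instance (or None), never a class.
        self.base_estimator = base_estimator

    def fit(self, X, y):
        # Assumption: fall back to a default LogisticRegression() instance.
        if self.base_estimator is None:
            base = LogisticRegression()
        else:
            base = self.base_estimator
        # Fit a clone so the estimator passed in by the caller stays unfitted.
        self.model_ = clone(base).fit(X, y)
        return self

    def predict(self, X):
        return self.model_.predict(X)


# Tiny toy data set, used only to exercise the wrapper.
X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])

wrapper = Wrapper(base_estimator=LogisticRegression(C=0.5))
params = wrapper.get_params(deep=True)  # includes nested keys such as "base_estimator__C"
predictions = wrapper.fit(X, y).predict(X)
wrapper_copy = clone(wrapper)  # works because base_estimator is an instance
```

Because base_estimator is an instance, get_params(deep=True) can descend into it and clone can copy it, which is what the estimator checks rely on.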

The error occurs because some of the checks call clone with safe=False:

safe: boolean, optional
    If safe is false, clone will fall back to a deepcopy on objects
    that are not estimators.
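
The mechanism can be reproduced without check_estimator. The sketch below uses a hypothetical helper, describe_get_params, to call get_params on the LogisticRegression class and on an instance. The class also has a get_params attribute, so code that only tests for that attribute will treat it as an estimator rather than falling back to a deepcopy, and the call then fails because no self is bound. This is a reading of the behavior described above for the release the question was written against; newer scikit-learn releases may treat class-valued parameters differently.

```python
from sklearn.linear_model import LogisticRegression


def describe_get_params(obj):
    """Hypothetical helper: call obj.get_params(deep=False) and report the outcome."""
    try:
        obj.get_params(deep=False)
        return "ok"
    except TypeError as exc:
        return f"TypeError: {exc}"


# Both the class and an instance expose a get_params attribute ...
class_has_attr = hasattr(LogisticRegression, "get_params")
instance_has_attr = hasattr(LogisticRegression(), "get_params")

# ... but only the instance can actually be asked for its parameters.
class_result = describe_get_params(LogisticRegression)
instance_result = describe_get_params(LogisticRegression())

print(class_result)
print(instance_result)
```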