继承自scikit-learn的LassoCV模型

时间:2016-10-13 15:40:13

标签: python scikit-learn

我尝试使用继承来扩展scikit-learn的RidgeCV模型:

from sklearn.linear_model import RidgeCV, LassoCV

class Extended(RidgeCV):
    def __init__(self, *args, **kwargs):
        super(Extended, self).__init__(*args, **kwargs)

    def example(self):
        print 'Foo'


x = [[1,0],[2,0],[3,0],[4,0], [30, 1]]
y = [2,4,6,8, 60]
model = Extended(alphas = [float(a)/1000.0 for a in range(1, 10000)])
model.fit(x,y)
print model.predict([[5,1]])

它工作得非常好,但是当我尝试从LassoCV继承时,它产生了以下追溯:

Traceback (most recent call last):
  File "C:/Python27/so.py", line 14, in <module>
    model.fit(x,y)
  File "C:\Python27\lib\site-packages\sklearn\linear_model\coordinate_descent.py", line 1098, in fit
    path_params = self.get_params()
  File "C:\Python27\lib\site-packages\sklearn\base.py", line 214, in get_params
    for key in self._get_param_names():
  File "C:\Python27\lib\site-packages\sklearn\base.py", line 195, in _get_param_names
    % (cls, init_signature))
RuntimeError: scikit-learn estimators should always specify their parameters in the signature of their __init__ (no varargs). <class '__main__.Extended'> with constructor (<self>, *args, **kwargs) doesn't  follow this convention.

有人可以解释如何解决这个问题吗?

1 个答案:

答案 0 :(得分:5)

你可能想制作scikit-learn兼容模型,进一步使用scikit-learn功能。如果你这样做 - 你需要先阅读: http://scikit-learn.org/stable/developers/contributing.html#rolling-your-own-estimator

很快:scikit-learn有许多功能,如估算器克隆(clone()函数),元算法,如GridSearchPipeline,交叉验证。所有这些都必须能够获得估算器内部字段的值,并更改这些字段的值(例如GridSearch必须在每次评估之前更改估算器内的参数),如参数{{1在alpha中。要更改某个参数的值,必须知道它的名称。要从SGDClassifier类(您隐式继承)中获取每个分类器方法get_params中所有字段的名称,需要在类的BaseEstimator方法中指定所有参数,因为它很容易内省__init__方法的所有参数名称(查看__init__,这是抛出此错误的类。)

所以它只是希望你删除像

这样的所有变种
BaseEstimator

来自*args, **kwargs 签名。您必须在__init__签名中列出模型的所有参数,并初始化对象的所有内部字段。

以下是SGDClassifier __init__方法的示例,该方法继承自__init__

BaseSGDClassifier