Scikit-Learn:自定义估计器的set_param()在组件之前设置嵌套参数

时间:2017-10-24 16:43:10

标签: scikit-learn

我在developer guide之后实施了几个自定义估算工具,因此所有这些都是从BaseEstimator继承的。其中一些使用其他scikit-learn估计器或变换器作为属性(例如,建立一个整体)。从BaseEstimator继承应该让我方便的是通过get_params()访问参数,并通过set_params()设置它们,如here所述,形式为component__parameter,例如用于网格搜索。在下面找到一个最小的例子。

from sklearn.base import BaseEstimator
from sklearn.linear_model import LinearRegression

class MyForecaster(BaseEstimator):

    def __init__(self, base_estimator=LinearRegression()):
        self.base_estimator = base_estimator

    def fit(self, X, y):
        pass

    def predict(self, X, y):
        pass

# instantiate forecaster and set parameters
mf = MyForecaster()
mf.set_params(**{"base_estimator" : "ElasticNet", "base_estimator__alpha": 0.05})

这失败了:

ValueError: Invalid parameter alpha for estimator LinearRegression. Check the list of available parameters with `estimator.get_params().keys()`.

这表示它尝试首先为嵌套属性设置params,而不是首先检查是否要覆盖“更高级别”属性(ElasticNet具有属性alpha,而不是LinearRegression)。

处理此问题的一种方法是为每个估算器覆盖set_params(),以确保正确处理它。

是否有任何“内置”方式来实现这一目标,我忽略了另一种解决方案?这是scikit-learn真正意图的行为吗?

编辑:

所以确实由于一些非常大的巧合,一个非常类似的问题似乎已经修复了版本0.19.1。但是,我的特殊情况仍然失败,只有管道的情况是固定的!

为了使其可重现,我将set_params()的当前代码复制到我的最小示例中(仅在第20行添加了注释)

1   def set_params(self, **params):
2       if not params:
3           # Simple optimization to gain speed (inspect is slow)
4           return self
5       valid_params = self.get_params(deep=True)
6
7       nested_params = defaultdict(dict)  # grouped by prefix
8       for key, value in params.items():
9           key, delim, sub_key = key.partition('__')
10          if key not in valid_params:
11              raise ValueError('Invalid parameter %s for estimator %s. '
12                               'Check the list of available parameters '
13                               'with `estimator.get_params().keys()`.' %
14                               (key, self))
15
16          if delim:
17              nested_params[key][sub_key] = value
18          else:
19              setattr(self, key, value)
20              #valid_params[key] = value
21
22      for key, sub_params in nested_params.items():
23          valid_params[key].set_params(**sub_params)
24
25      return self

它失败了,因为它将在第19行设置属性,但由于它不更新valid_params,因此在尝试设置属性时,它在下一次迭代中仍然会失败。所以我添加了第20行来解决这个问题。 它确实在0.19.1的当前修订中测试,因为它仅针对管道进行了测试。 Here,set_param()会被覆盖,以便首先调用_BaseComposition的_set_param(),其中可以处理这个问题。

我应该在scikit-learn github中提出这个问题还是重新打开另一个问题?

2 个答案:

答案 0 :(得分:1)

这是一个错误。已报告a week ago,已been fixed已向v0.19.1移出v0.19.1,已发布yesterday

最简单的解决方法是将scikit-learn更新为matrix4(matrix4& m)(或掌握dev分支)。

答案 1 :(得分:1)

因此,@ TomDLT的答案中提到的fix修复了一个非常类似的问题,导致上面的修复最有可能成为未来版本的sklearn(9999)。

所以对于这里:如果您在此期间遇到问题,请使用上面的代码覆盖set_params()或等待修复。