我在developer guide之后实施了几个自定义估算工具,因此所有这些都是从BaseEstimator继承的。其中一些使用其他scikit-learn估计器或变换器作为属性(例如,建立一个整体)。从BaseEstimator继承应该让我方便的是通过get_params()访问参数,并通过set_params()设置它们,如here所述,形式为component__parameter,例如用于网格搜索。在下面找到一个最小的例子。
from sklearn.base import BaseEstimator
from sklearn.linear_model import LinearRegression
class MyForecaster(BaseEstimator):
def __init__(self, base_estimator=LinearRegression()):
self.base_estimator = base_estimator
def fit(self, X, y):
pass
def predict(self, X, y):
pass
# instantiate forecaster and set parameters
mf = MyForecaster()
mf.set_params(**{"base_estimator" : "ElasticNet", "base_estimator__alpha": 0.05})
这失败了:
ValueError: Invalid parameter alpha for estimator LinearRegression. Check the list of available parameters with `estimator.get_params().keys()`.
这表示它尝试首先为嵌套属性设置params,而不是首先检查是否要覆盖“更高级别”属性(ElasticNet具有属性alpha,而不是LinearRegression)。
处理此问题的一种方法是为每个估算器覆盖set_params(),以确保正确处理它。
是否有任何“内置”方式来实现这一目标,我忽略了另一种解决方案?这是scikit-learn真正意图的行为吗?
所以确实由于一些非常大的巧合,一个非常类似的问题似乎已经修复了版本0.19.1。但是,我的特殊情况仍然失败,只有管道的情况是固定的!
为了使其可重现,我将set_params()的当前代码复制到我的最小示例中(仅在第20行添加了注释)
1 def set_params(self, **params):
2 if not params:
3 # Simple optimization to gain speed (inspect is slow)
4 return self
5 valid_params = self.get_params(deep=True)
6
7 nested_params = defaultdict(dict) # grouped by prefix
8 for key, value in params.items():
9 key, delim, sub_key = key.partition('__')
10 if key not in valid_params:
11 raise ValueError('Invalid parameter %s for estimator %s. '
12 'Check the list of available parameters '
13 'with `estimator.get_params().keys()`.' %
14 (key, self))
15
16 if delim:
17 nested_params[key][sub_key] = value
18 else:
19 setattr(self, key, value)
20 #valid_params[key] = value
21
22 for key, sub_params in nested_params.items():
23 valid_params[key].set_params(**sub_params)
24
25 return self
它失败了,因为它将在第19行设置属性,但由于它不更新valid_params,因此在尝试设置属性时,它在下一次迭代中仍然会失败。所以我添加了第20行来解决这个问题。 它确实在0.19.1的当前修订中测试,因为它仅针对管道进行了测试。 Here,set_param()会被覆盖,以便首先调用_BaseComposition的_set_param(),其中可以处理这个问题。
我应该在scikit-learn github中提出这个问题还是重新打开另一个问题?
答案 0 :(得分:1)
这是一个错误。已报告a week ago,已been fixed已向v0.19.1
移出v0.19.1
,已发布yesterday。
最简单的解决方法是将scikit-learn更新为matrix4(matrix4& m)
(或掌握dev分支)。
答案 1 :(得分:1)