实现自定义scikit-learn估算器的完整规范是什么?

时间:2014-05-26 09:22:22

标签: python scikit-learn

我正在使用我自己的预测器,并希望像使用任何scikit例程一样使用它(例如RandomForestRegressor)。我有一个包含fitpredict方法的类似乎工作正常。但是,当我尝试使用某些scikit方法时,例如交叉验证,我会收到如下错误:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Python27\lib\site-packages\sklearn\cross_validation.py", line 1152, in cross_val_
score
    for train, test in cv)
  File "C:\Python27\lib\site-packages\sklearn\externals\joblib\parallel.py", line 516, in __
call__
    for function, args, kwargs in iterable:
  File "C:\Python27\lib\site-packages\sklearn\cross_validation.py", line 1152, in <genexpr>
    for train, test in cv)
  File "C:\Python27\lib\site-packages\sklearn\base.py", line 43, in clone
    % (repr(estimator), type(estimator)))
TypeError: Cannot clone object '<__main__.Custom instance at 0x033A6990>' (type <type 'inst
ance'>): it does not seem to be a scikit-learn estimator a it does not implement a 'get_para
ms' methods.

我看到它希望我实现一些方法(大概是get_params以及set_paramsscore)但是我不确定制作这些方法的正确规范是什么是。是否有关于此主题的信息?感谢。

1 个答案:

答案 0 :(得分:11)

scikit-learn docs中提供了完整说明,API背后的原则在this paper by yours truly et al.中列出。简而言之,除了fit之外,估算工具所需要的是{{1} }和get_params返回(作为set_params)并设置(来自kwargs)估计器的超参数,即学习算法本身的参数(与其学习的数据参数相反)。这些参数应与dict参数匹配。

这两种方法都可以通过继承__init__中的类来获得,但如果您不希望您的代码依赖于scikit-learn,则可以自己提供。

请注意,输入验证应该在sklearn.base而不是构造函数中完成,否则您仍然可以在fit中设置无效参数,并使set_params以意外方式失败。