如何将关键字参数传递给sklearn管道中的predict方法

时间:2014-08-26 15:29:12

标签: python scikit-learn

我在GaussianProcess内使用Pipelinepredict的{​​{1}}方法接受名为GaussianProcess的{​​{1}}方法的关键字参数,我需要使用该方法来填补我的记忆。

有没有办法在通过配置的管道调用predict时将此参数传递给batch_size实例?

这是一个根据sklearn文档改编的最小示例,用于演示我想要的内容:

GaussianProcess

3 个答案:

答案 0 :(得分:1)

虽然可以向管道的fitfit_transform方法添加fit-parameters,但predict无法实现。请参阅版本0.15代码中的this line及随后的版本。

您可以使用

对其进行monkeypatch
from functools import partial
gp.predict = partial(gp.predict, batch_size=10)

或者,如果那不起作用,那么

pipeline.steps[-1][-1].predict = partial(pipeline.steps[-1][-1].predict, batch_size=10)

答案 1 :(得分:1)

正在寻找-在其他地方找到了答案,但想在这里分享,因为这是我发现的第一个相关问题。

sklearn的当前版本中,您可以将关键字args传递给管道,然后将其传递给预测变量(即管道中的最后一个元素):

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.pipeline import Pipeline

class Predictor(BaseEstimator, ClassifierMixin):
    def fit(self, *_):
        return self

    def predict(self, X, additional_arg):
        return f'ok: {X}, {additional_arg}'

pipe = Pipeline([
    ('passthrough', 'passthrough'),  # anything here would *not* see the keyword arg
    ('p', Predictor())
])

print(Predictor().predict('one', 'two'))
print(pipe.predict('three', additional_arg='four'))  # must be passed as keyword argument

# DO NOT:
print(pipe.predict('three')) # would raise an exception: missing parameter
print(pipe.predict('three', 'four')) # would raise an exception: no positional args allowed

答案 2 :(得分:0)

您可以通过允许将关键字参数**predict_params传递给管道的预测方法来解决此问题。

from sklearn.pipeline import Pipeline

class Pipeline(Pipeline):
    def predict(self, X, **predict_params):
        """Applies transforms to the data, and the predict method of the
        final estimator. Valid only if the final estimator implements
        predict."""
        Xt = X
        for name, transform in self.steps[:-1]:
            Xt = transform.transform(Xt)
        return self.steps[-1][-1].predict(Xt, **predict_params)