scikit学习。多处理管道中的自定义估算器,n_jobs = -1

时间:2017-08-27 08:38:28

标签: python python-3.x scikit-learn

我写了自定义变换器并在scikit中构建了管道学习。现在我正在尝试使用GridSearchCV调整此管道。一切正常,直到我试图把n_jobs = -1加速进程。

GUI Jupyter笔记本没有写任何问题,只显示内核正忙,但在控制台内部打印出以下错误多次复制: Jupyter notebook consone log Drop Fields是我的一个自定义变换器的名称(不知道这是否重要,但它是管道的第一步)。它的定义如下:

class DropFields(FieldsTransformerMixin, Transformer):  
    def __init__(self, fields=None, all_except=False):
        self.fields = fields
        self.all_except = all_except

    def _fit_before(self, data):
        self.fields_ = list(set(data.columns) - set(self.fields)) if self.all_except else list(self.fields)

    def _transform_before(self, data):
        return  data.drop(self.fields_, axis=1)

父母(在上面的单元格中定义的笔记本中):

class Transformer(BaseEstimator, TransformerMixin):
    pass

class FieldsTransformerMixin:
    def __init__(self, fields=None):
        self.fields = fields 

    def fit(self, data, y=None):
        self._validate_params()
        self._fit_before(data)
        for field in self.fields:
            self._fit_field(field, data)
        return self

    def transform(self, data):
        data = data.copy()
        data = self._transform_before(data)
        for field in self.fields:
            data = self._transform_field(field, data)
        return data

    def _validate_params(self):
        if self.fields is None:
            raise ValueError('Fields is none.')

    ... empty definitions of _fit_before,_fit_field,
    ... definitions of _transform_before and _transform_field returning default data

问题是:

我需要在自定义估算器中实现特定逻辑,以便将它们与 n_jobs = -1 一起使用,如果没有,那么这里有什么问题?为什么多处理无法找到“DropFields”?

1 个答案:

答案 0 :(得分:1)

scikit-learn估计器需要支持克隆以进行交叉验证,以使用n_jobs = -1对其进行多处理。您可以查看其文档here。 base.clone接受您的自定义估算器类,调用其get_params()函数,并通过使用与从原始类的get_params()方法中检索到的相同参数重复调用类初始化器来创建副本。

您的自定义估算器类也必须是“可拾取的”。如果您的自定义估算器类未在模块的根目录中定义,则将导致错误。估计器类需要获取其所需的所有内容作为参数,并与模块分开运行。