Getting the intermediate data state in a scikit-learn Pipeline

Date: 2018-02-12 09:22:53

Tags: python scikit-learn

Given the following example:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import NMF
from sklearn.pipeline import Pipeline
import pandas as pd

pipe = Pipeline([
    ("tf_idf", TfidfVectorizer()),
    ("nmf", NMF())
])

data = pd.DataFrame([["Salut comment tu vas", "Hey how are you today", "I am okay and you ?"]]).T
data.columns = ["test"]

pipe.fit_transform(data.test)

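I would like to get the intermediate data state in the scikit-learn pipeline that corresponds to the output of tf_idf (after fit_transform on tf_idf, but before NMF), i.e. the input to NMF. Put another way, I want the same result as applying
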
TfidfVectorizer().fit_transform(data.test)

I know that pipe.named_steps["tf_idf"] gives me the intermediate transformer, but that way I only get the transformer and its parameters, not the data.

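3 answers:

Answer 0 (score: 5)

As @Vivek Kumar suggested in the comments, and as I answered here, I find a debug step that prints information or writes the intermediate DataFrame to a CSV file useful: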

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import NMF
from sklearn.pipeline import Pipeline
import pandas as pd
from sklearn.base import TransformerMixin, BaseEstimator


class Debug(BaseEstimator, TransformerMixin):

    def transform(self, X):
        print(X.shape)
        self.shape = X.shape  # store the shape so it can be read after fitting
        # what other output you want
        return X

    def fit(self, X, y=None, **fit_params):
        return self

pipe = Pipeline([
    ("tf_idf", TfidfVectorizer()),
    ("debug", Debug()),
    ("nmf", NMF())
])

data = pd.DataFrame([["Salut comment tu vas", "Hey how are you today", "I am okay and you ?"]]).T
data.columns = ["test"]

pipe.fit_transform(data.test)

Edit

I have now added state to the debug transformer. You can now access the shape as @datasailor does in their answer:

pipe.named_steps["debug"].shape
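Along the same lines, the debug step can keep a reference to the data itself instead of only its shape, which gives you the NMF input the question asks for. This is a minimal sketch: the class name `Capture`, its attribute `X_`, and `n_components=2` for NMF are illustrative assumptions, not part of the answer above.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.decomposition import NMF
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline


class Capture(BaseEstimator, TransformerMixin):
    """Hypothetical pass-through step that remembers the last data it transformed."""

    def fit(self, X, y=None):
        return self  # nothing to learn

    def transform(self, X):
        self.X_ = X  # keep a reference to the intermediate data
        return X     # hand the data on unchanged


pipe = Pipeline([
    ("tf_idf", TfidfVectorizer()),
    ("capture", Capture()),
    ("nmf", NMF(n_components=2, random_state=0)),  # n_components=2 is an arbitrary choice
])

data = pd.DataFrame([["Salut comment tu vas", "Hey how are you today", "I am okay and you ?"]]).T
data.columns = ["test"]

pipe.fit_transform(data.test)
tfidf_matrix = pipe.named_steps["capture"].X_  # the TF-IDF output, i.e. the NMF input
print(tfidf_matrix.shape)  # (3, 12)
```

Because `X_` is overwritten on every call to `transform`, it reflects whatever data last flowed through the pipeline, for example new documents passed to `pipe.transform`.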

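Answer 1 (score: 4)

As I understand it, you want the transformed training data. The fitted transformer is already available as pipe.named_steps["tf_idf"], so just use that fitted model to transform the training data again: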

pipe.named_steps["tf_idf"].transform(data.test)
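A runnable version of this approach, assuming the question's data and an arbitrary `n_components=2` for NMF. It also shows pipeline slicing, which scikit-learn supports from version 0.21: `pipe[:-1]` is a pipeline made of every fitted step except the last, so its `transform` returns the NMF input directly. Both results are compared against a standalone `TfidfVectorizer().fit_transform`, as in the question.

```python
import pandas as pd
from sklearn.decomposition import NMF
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline

pipe = Pipeline([
    ("tf_idf", TfidfVectorizer()),
    ("nmf", NMF(n_components=2, random_state=0)),  # n_components=2 is an arbitrary choice
])

data = pd.DataFrame([["Salut comment tu vas", "Hey how are you today", "I am okay and you ?"]]).T
data.columns = ["test"]
pipe.fit_transform(data.test)

# Re-apply the fitted vectorizer to the training data (the approach in this answer).
via_named_step = pipe.named_steps["tf_idf"].transform(data.test)

# Alternative: drop the final step and transform with the fitted steps that remain.
via_slice = pipe[:-1].transform(data.test)

# Both should match fitting a standalone vectorizer on the same data.
standalone = TfidfVectorizer().fit_transform(data.test)
print(abs(via_named_step - standalone).max(), abs(via_slice - standalone).max())
```

The slicing form generalizes to longer pipelines (`pipe[:k]` gives the output of the first k steps), and the sliced pipeline reuses the already fitted estimators rather than refitting them.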

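Answer 2 (score: 0)

I created a gist for this. Essentially, from Python 3.2 onward, the code below uses a context manager to collect the intermediate results into a dictionary keyed by the names of the pipeline's transformers.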

with intermediate_transforms(pipe):
    Xt = pipe.transform(X)
    intermediate_results = pipe.intermediate_results__

This is done by the function below; see my gist for more documentation.

import contextlib
from functools import partial

from sklearn.pipeline import Pipeline

@contextlib.contextmanager
def intermediate_transforms(pipe: Pipeline):
    # Our temporary overload of Pipeline._transform() method.
    # https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/pipeline.py
    def _pipe_transform(self, X):
        Xt = X
        for _, name, transform in self._iter():
            Xt = transform.transform(Xt)
            self.intermediate_results__[name] = Xt
        return Xt

    if not isinstance(pipe, Pipeline):
        raise ValueError(f'"{pipe}" must be a Pipeline.')

    pipe.intermediate_results__ = {}
    _transform_before = pipe._transform
    pipe._transform = partial(_pipe_transform, pipe)  # Monkey-patch our _pipe_transform method.
    try:
        yield pipe  # Release our patched object to the context
    finally:
        # Restore the original method, even if the with-block raised an exception
        pipe._transform = _transform_before
        delattr(pipe, 'intermediate_results__')
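The context manager above works by monkey-patching the private `Pipeline._transform` method, so it depends on scikit-learn internals: it only takes effect if `Pipeline.transform` delegates to `_transform`, which may not hold in recent releases. A version-independent sketch that uses only the public `steps` attribute is to walk the fitted steps yourself. The helper name `intermediate_outputs` and the `n_components=2` setting are illustrative assumptions, and every step, including the last one, is assumed to implement `transform`.

```python
from sklearn.decomposition import NMF
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline


def intermediate_outputs(pipe: Pipeline, X):
    """Hypothetical helper: apply each fitted step in turn and record its output by name."""
    results = {}
    Xt = X
    for name, step in pipe.steps:
        if step is None or step == "passthrough":
            continue  # skipped steps do not change the data
        Xt = step.transform(Xt)
        results[name] = Xt
    return results


docs = ["Salut comment tu vas", "Hey how are you today", "I am okay and you ?"]
pipe = Pipeline([
    ("tf_idf", TfidfVectorizer()),
    ("nmf", NMF(n_components=2, random_state=0)),
])
pipe.fit(docs)

outputs = intermediate_outputs(pipe, docs)
print(outputs["tf_idf"].shape, outputs["nmf"].shape)  # (3, 12) (3, 2)
```

Unlike the context manager, this returns the dictionary directly and leaves the pipeline object untouched.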