Transformer error when used in a Pipeline

Date: 2018-06-12 08:46:48

Tags: python machine-learning scikit-learn

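I'm learning how to create and use pipelines in scikit-learn. I tried to build a pipeline with a LabelEncoder followed by a OneHotEncoder.

When I run the transformers one after another by hand, chaining them, everything works. But when I put them into a Pipeline, it raises an error.

Here is my code.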

import pandas as pd
import numpy as np

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.pipeline import Pipeline
from sklearn.pipeline import make_pipeline

data = [
    "cat1",
    "cat2",
    "cat3"
]

df = pd.DataFrame(data, columns=["category"])

class DataFrameSelector(BaseEstimator, TransformerMixin):
    def __init__(self, attribute_names):
        self.attribute_names = attribute_names

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X[self.attribute_names].values.flatten()

class Reshaper(BaseEstimator, TransformerMixin):
    def __init__(self):
        pass

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X.reshape(-1, 1)

cat_attribs = ["category"]

dfs = DataFrameSelector(cat_attribs)
le = LabelEncoder()
rs = Reshaper()
ohe = OneHotEncoder()

### this works

transformed = ohe.fit_transform(
    rs.fit_transform(
        le.fit_transform(
            dfs.fit_transform(df)
        )
    )
)

transformed.toarray()

###

# cat_pipeline = Pipeline([
#     ('selector', dfs),
#     ('label_encoder', le),
#     ('reshaper', rs),
#     ('cat_encoder', ohe),
# ])

# transformed = cat_pipeline.fit_transform(df)

If I run the commented-out code, I get the error: fit_transform() takes 2 positional arguments but 3 were given
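The message points at a signature mismatch. Below is a simplified, hypothetical model of it in plain Python, with no scikit-learn involved. LabelEncoderLike, ReshaperLike and run_pipeline are made-up names. The sketch assumes, as scikit-learn's Pipeline does, that every intermediate step is fitted with fit_transform(X, y). LabelEncoder's own fit_transform accepts only a single argument y, because it is meant for encoding target labels rather than input features.

```python
class LabelEncoderLike:
    """Hypothetical stand-in whose fit_transform takes only y, like LabelEncoder's."""

    def fit_transform(self, y):
        classes = sorted(set(y))
        return [classes.index(label) for label in y]


class ReshaperLike:
    """Hypothetical stand-in for an ordinary transformer: fit_transform(X, y=None)."""

    def fit_transform(self, X, y=None):
        return [[value] for value in X]


def run_pipeline(steps, X, y=None):
    """Simplified model of pipeline fitting: each step is called as fit_transform(X, y)."""
    for _name, step in steps:
        X = step.fit_transform(X, y)
    return X


data = ["cat1", "cat2", "cat3"]

# Chained by hand, LabelEncoderLike receives a single argument, so this works.
direct = ReshaperLike().fit_transform(LabelEncoderLike().fit_transform(data))
print(direct)  # [[0], [1], [2]]

# Called the way a pipeline calls it, both X and y are passed positionally.
try:
    run_pipeline([("label_encoder", LabelEncoderLike()), ("reshaper", ReshaperLike())], data)
    error_message = ""
except TypeError as exc:
    error_message = str(exc)
print(error_message)
```

Counting self, the pipeline-style call passes three positional arguments (self, X, y) to a method that declares only two (self, y), which matches the numbers in the reported message.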

What is going wrong?

P.S. I can accomplish this task with LabelBinarizer. I was just exploring pipelines when I stumbled on this unexpected behavior.
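For pipeline use specifically, one common workaround is to wrap LabelEncoder in a small transformer that exposes the fit(X, y=None) / transform(X) signatures a Pipeline expects. This is a minimal sketch, assuming scikit-learn and pandas are installed; ColumnSelector and PipelineLabelEncoder are hypothetical names, and the adapter also performs the reshaping that the Reshaper step did in the question.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder, OneHotEncoder


class ColumnSelector(TransformerMixin, BaseEstimator):
    """Hypothetical helper: returns one DataFrame column as a 1-D array."""

    def __init__(self, column):
        self.column = column

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X[self.column].to_numpy()


class PipelineLabelEncoder(TransformerMixin, BaseEstimator):
    """Hypothetical adapter: LabelEncoder behind the fit(X, y) / transform(X) API."""

    def fit(self, X, y=None):
        # Treat the incoming X as the labels to encode; the pipeline's y is ignored.
        self.encoder_ = LabelEncoder().fit(X)
        return self

    def transform(self, X):
        # Encode to integers and return a column vector, as OneHotEncoder expects 2-D input.
        return self.encoder_.transform(X).reshape(-1, 1)


df = pd.DataFrame({"category": ["cat1", "cat2", "cat3"]})

cat_pipeline = Pipeline([
    ("selector", ColumnSelector("category")),
    ("label_encoder", PipelineLabelEncoder()),
    ("cat_encoder", OneHotEncoder()),
])

# TransformerMixin.fit_transform(X, y=None) calls fit(X, y).transform(X),
# so both custom steps accept the (X, y) call the pipeline makes.
transformed = cat_pipeline.fit_transform(df)
print(transformed.toarray())
```

Note also that from scikit-learn 0.20 onward OneHotEncoder accepts string categories directly, so on recent versions the LabelEncoder step can be dropped and the selected column, as a 2-D array, passed straight to OneHotEncoder.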
