我正在学习如何在 scikit-learn 中创建和使用管道(Pipeline)。我尝试创建一个管道:先用 LabelEncoder,再接 OneHotEncoder。
当我把这些转换器(transformer)逐个串联调用时,一切正常;但当我把它们放进管道中时,就会出错。
这是我的代码。
import pandas as pd
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.pipeline import Pipeline
from sklearn.pipeline import make_pipeline
# Toy input: three distinct category labels in a single "category" column.
data = ["cat1", "cat2", "cat3"]
df = pd.DataFrame({"category": data})
class DataFrameSelector(BaseEstimator, TransformerMixin):
    """Select DataFrame columns and pass them on as one flat 1-D array.

    Parameters
    ----------
    attribute_names
        Column labels to pull out of the incoming DataFrame
        (e.g. ``["category"]``).

    Notes
    -----
    The selected values are flattened in row-major order. With more than
    one column, the values of different columns therefore end up
    interleaved row by row. This script only uses it with one column.
    """

    def __init__(self, attribute_names):
        # Stored unchanged under the constructor argument's own name,
        # following the scikit-learn estimator convention.
        self.attribute_names = attribute_names

    def fit(self, X, y=None):
        """Learn nothing; return the selector itself."""
        return self

    def transform(self, X):
        """Return the chosen columns of ``X`` as a flattened NumPy array."""
        as_array = X[self.attribute_names].values
        return as_array.flatten()
class Reshaper(BaseEstimator, TransformerMixin):
    """Turn a 1-D sequence into a single-column 2-D array of shape (n, 1).

    Stateless. In this script it sits between LabelEncoder (which produces
    1-D output) and OneHotEncoder (which is fed the 2-D column).
    """

    # No __init__: the estimator takes no parameters, so the inherited
    # object.__init__ is enough and BaseEstimator reports no params.

    def fit(self, X, y=None):
        """Nothing to learn; return self so calls can be chained."""
        return self

    def transform(self, X):
        """Return ``X`` as an ``(n_samples, 1)`` NumPy array.

        ``np.asarray`` lets this accept plain lists and pandas Series as
        well as ndarrays; a Series has no ``.reshape`` method. For an
        ndarray the result is identical to ``X.reshape(-1, 1)``.
        """
        return np.asarray(X).reshape(-1, 1)
cat_attribs = ["category"]
dfs = DataFrameSelector(cat_attribs)
le = LabelEncoder()
rs = Reshaper()
ohe = OneHotEncoder()

### this works
# Run each transformer by hand, feeding one step's output into the next.
selected = dfs.fit_transform(df)      # flat 1-D array of category values
encoded = le.fit_transform(selected)  # integer code per category value
column = rs.fit_transform(encoded)    # reshaped to an (n, 1) column
transformed = ohe.fit_transform(column)
transformed.toarray()  # shown in an interactive session; result not kept
###

# The pipeline below raises
#   TypeError: fit_transform() takes 2 positional arguments but 3 were given
# Pipeline calls every intermediate step as step.fit_transform(X, y).
# LabelEncoder.fit_transform only accepts a single argument (y), because it
# is designed to encode target labels rather than feature columns.
# cat_pipeline = Pipeline([
# ('selector', dfs),
# ('label_encoder', le),
# ('reshaper', rs),
# ('cat_encoder', ohe),
# ])
# transformed = cat_pipeline.fit_transform(df)
如果我尝试运行被注释掉的那段代码,就会报错:`fit_transform() takes 2 positional arguments but 3 were given`。
出了什么问题?
P.S. 我知道可以用 LabelBinarizer 完成这个任务。我只是在探索管道时,偶然遇到了这种意想不到的行为。