使用sklearn OneHotEncoder时如何忽略数字列?

时间:2020-03-22 05:10:18

标签: python pandas scikit-learn one-hot-encoding

环境:

import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier

样本数据:

X_train = pd.DataFrame({'A': ['a1', 'a3', 'a2'], 
                        'B': ['b2', 'b1', 'b3'],
                        'C': [1, 2, 3]})
y_train = pd.DataFrame({'Y': [1,0,1]})

所需结果: 我想以以下格式在管道中包含sklearn OneHotEncoder:

encoder = ### SOME CODE ###
scaler = StandardScaler()
model = RandomForestClassifier(random_state=0)

# This is my ideal pipeline
pipe = Pipeline([('OneHotEncoder', encoder),
                 ('Scaler', scaler),
                 ('Classifier', model)])
pipe.fit(X_train, y_train)

挑战: OneHotEncoder正在对包括数字列在内的所有内容进行编码。我想保持数值列不变,并以与Pipeline()兼容的有效方式仅对分类特征进行编码。

encoder = OneHotEncoder(drop='first', sparse=False) 
encoder.fit(X_train)
encoder.transform(X_train) # Columns C is encoded - this is what I want to avoid

解决方法(不理想):我可以使用pd.get_dummies()解决该问题。但是,这意味着我无法将其包含在我的管道中。还是有办法?

X_train = pd.get_dummies(X_train, drop_first=True)

2 个答案:

答案 0 :(得分:1)

我要做的是创建自己的自定义转换器并将其放入管道中。这样,您手中的数据将具有很大的威力。因此,步骤如下:

1)创建一个继承BaseEstimatorTransformerMixin的自定义转换器类。在其transform()函数中,尝试检测该列的值是数字的还是分类的。如果您现在不想处理逻辑,则始终可以将分类列的列名称提供给transform()函数,以便随时进行选择。

2)(可选)创建自定义转换器,以仅处理分类值的列。

3)(可选)创建您的自定义转换器,以仅处理具有数值的列。

4)使用您创建的转换器构建两个管道(一个用于分类,另一个用于数字),您也可以使用sklearn中的现有管道。

5)用FeatureUnion合并两个管道。

6)将大型管道与ML模型合并。

7)致电fit_transform()

示例代码(未实现可选选项):GitHub Jupyter Noteboook

答案 1 :(得分:0)

对此,我的首选解决方案是使用sklearn的ColumnTransformer(请参阅here)。

它使您可以根据需要将数据分为任意多个组(在您的情况下是分类数据还是数字数据),并对这些组应用不同的预处理操作。然后,该变压器可以像其他任何sklearn预处理工具一样在管道中使用。这是一个简短的示例:

import pandas as pd
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier

X = pd.DataFrame({"a":[1,2,3],"b":["A","A","B"]})
y = np.array([0,1,1])

OHE = OneHotEncoder()
scaler = StandardScaler()
RFC = RandomForestClassifier()

cat_cols = ["b"]
num_cols = ["a"]

transformer = ColumnTransformer([('cat_cols', OHE, cat_cols),
                                ('num_cols', scaler, num_cols)])

pipe = Pipeline([("preprocessing", transformer),
                ("classifier", RFC)])
pipe.fit(X,y)

注意:我已根据您的请求获得许可,因为这仅将定标器应用于数字数据,我认为这更有意义吗?如果确实要将缩放器应用于所有列,则也可以通过修改此示例来做到这一点。