Variable names in an sklearn pipeline

Date: 2021-07-12 19:25:57

Tags: python scikit-learn sklearn-pandas

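I need to use DecisionTreeClassifier from the sklearn library. My dataset has several columns that I have to one-hot encode (turn into dummy variables). My problem is that the variables in the resulting model have meaningless names: feature_1, feature_2, ..., feature_n. How can I give them their real names? My dataset has about 400 columns, so renaming them by hand is not a good option. Thanks.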

import pandas as pd

from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split


raw_data = {'sum': [2345, 256,  43, 643, 34 , 23, 95], 
        'department': ['a1', 'a1', 'a3', 'a3', 'a1', 'a2', 'a2'],
        'sex': ['m', 'neudane', 'f', '', 'f', 'f', 'f']}
df = pd.DataFrame(raw_data, columns = ['sum', 'department', 'sex'])

y = {'y': ['cat_a', 'cat_a', 'cat_b', 'cat_c', 'cat_b', 'cat_a', 'cat_a']}

y = pd.DataFrame(y, columns = ['y'])


categorical = ['department', 'sex']

numerical = ['sum']


X = df[categorical + numerical]


categorical_pipeline = Pipeline([
    ("imputer", SimpleImputer(strategy="most_frequent")),
    # `sparse` was renamed to `sparse_output` in scikit-learn 1.2
    ("encoder", OneHotEncoder(sparse_output=True, handle_unknown="ignore"))
])

numerical_pipeline = Pipeline([
    ("imputer", SimpleImputer(strategy="mean")),
    ("scaler", StandardScaler())
])



basic_preprocessor = ColumnTransformer([
    #("nominal_preprocessor", nominal_pipeline, nominal),
    ("categorical_preprocessor", categorical_pipeline, categorical),
    ("numerical_preprocessor", numerical_pipeline, numerical)
])


preprocessed = basic_preprocessor.fit_transform(X)


X = preprocessed


train, test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

from sklearn import tree
from sklearn.tree import export_text
clf = tree.DecisionTreeClassifier()
clf = clf.fit(train, y_train)


r = export_text(clf)
print(r)



Output:
|--- feature_1 <= 0.50
|   |--- feature_7 <= -0.19
|   |   |--- class: cat_b
|   |--- feature_7 >  -0.19
|   |   |--- class: cat_c
|--- feature_1 >  0.50
|   |--- class: cat_a
