结合Sklearn管道中的功能

时间:2019-09-24 08:22:36

标签: python scikit-learn

我想使用包含TfidfVectorizerSVC的管道。但是,在这两者之间,我想将从非文本数据中提取的某些功能连接到TfidfVectorizer的输出中。

我尝试创建自定义类(基于此tutorial的方法)来执行此操作,但这似乎不起作用。

这是我到目前为止尝试过的:

pipeline = Pipeline([
    ('tfidf', TfidfVectorizer()),
    ('transformer', CustomTransformer(one_hot_feats)),
    ('clf', MultinomialNB()),
])

parameters = {
    'tfidf__min_df': (5, 10, 15, 20, 25, 30),
    'tfidf__max_df': (0.8, 0.9, 1.0),
    'tfidf__ngram_range': ((1, 1), (1, 2)),
    'tfidf__norm': ('l1', 'l2'),
    'clf__alpha': np.linspace(0.1, 1.5, 15),
    'clf__fit_prior': [True, False],
}

grid_search = GridSearchCV(pipeline, parameters, cv=5, n_jobs=-1, verbose=1)
grid_search.fit(df["short description"], labels)

这是CustomTransformer

class CustomTransformer(TransformerMixin):
"""Class that concatenates the one hot encode category feature with the tfidf data."""

def __init__(self, one_hot_features):
    """Initializes an instance of our custom transformer."""
    self.one_hot_features = one_hot_features

def fit(self, X, y=None, **kwargs):
    """Dummy fit function that does nothing particular."""

    return self

def transform(self, X, y=None, **kwargs):
    """Adds our external features"""
    return numpy.hstack((one_hot_feats, X))   

只要X不会更改自定义类中的尺寸(可能是与TransformerMixin相关的限制),此方法就可以工作,但是,就我而言,我将在数据中附加其他功能。我的自定义类是否应该继承自其他基类,或者有其他方法可以解决此问题?

1 个答案:

答案 0 :(得分:2)

您可以使用Sklearn的FeatureUnion组合多个功能,并使用 ColumnTransformer 转换特定的列:

来自文档

  

功能联盟

     

连接多个转换器对象的结果。

     

此估算器将一系列变压器对象并行应用于   输入数据,然后将结果连接起来。这对   将多种特征提取机制组合为一个   变压器。

     

ColumnTransformer

     

将转换器应用于数组或熊猫DataFrame的列。

     

此估算器允许输入的不同列或列子集   分别进行变换,并分别生成特征   转换器将被串联以形成单个特征空间。这个   对于异构或列式数据很有用,可以将多个   特征提取机制或转换为单个   变压器。

您可以使用make_column_transformer

from sklearn.compose import make_column_transformer
pipeline = Pipeline([
    ('transformer',  make_column_transformer((TfidfVectorizer(), ['text_column']),
                                             (OneHotEncoder(), ['categorical_column']),)),
    ('clf', MultinomialNB()),
])

编辑:

make_column_transformer中将remainder设置为'passthrough',以便所有未在转换器中指定的其余列都将自动通过。