如何从管道中的sklearn TFIDF矢量化器返回数据帧?

时间:2018-10-11 11:37:41

标签: python pandas scikit-learn tf-idf

如何让TFIDF Vectorizer在用于交叉验证的sklearn管道内返回具有相应列名称的熊猫数据框?

我有一条Sklearn管道,其中一个步骤是TFIDF矢量化器:

class InspectPipeline(BaseEstimator, TransformerMixin):

    def transform(self, x):
        return x

    def fit(self, x, y=None):
        self.df = x
        return self


pipeline = Pipeline(
        [
         ("selector", ItemSelector(key="text_column")),
         ("vectorizer", TfidfVectorizer()),
         ("debug", InspectPipeline()),
         ("classifier", RandomForestClassifier())
        ]
)

我创建了类InspectPipeline,以便稍后检查传递给分类器的功能(通过运行pipeline.best_estimator_.named_steps['debug'].df)。但是,TfidfVectorizer返回一个稀疏矩阵,这是我执行pipeline.best_estimator_.named_steps['debug'].df时得到的。我不想获得稀疏矩阵,而是希望将TFIDF向量作为熊猫数据帧获得,其中列名称分别是tfidf令牌。

我知道tfidf_vectorizer.get_feature_names()可以帮助您了解列名。但是,如何在流水线中包含将稀疏矩阵转换为数据帧的功能呢?

2 个答案:

答案 0 :(得分:3)

您可以扩展TfidfVectorizer来返回带有所需列名的DataFrame,然后在管道中使用它。

from sklearn.feature_extraction.text import TfidfVectorizer
import pandas as pd

class DenseTfidfVectorizer(TfidfVectorizer):

    def transform(self, raw_documents, copy=True):
        X = super().transform(raw_documents, copy=copy)
        df = pd.DataFrame(X.toarray(), columns=self.get_feature_names())
        return df

    def fit_transform(self, raw_documents, y=None):
        X = super().fit_transform(raw_documents, y=y)
        df = pd.DataFrame(X.toarray(), columns=self.get_feature_names())
        return df

答案 1 :(得分:0)

根据docs,您可以使用以下方法

a。直接访问管道外部的.get_feature_names()并检查那里的数据框(带有命名列)

b。 apply .fit_transform on data在管道之外

pipeline = Pipeline(....)

# a. extract .get_feature_names() to use as column names in the dataframe
feature_names = (
                pipeline.best_estimator_
                        .named_steps['vectorizer']
                        .get_feature_names()
                )    

# b. get the TFIDF vector
data2 = (
         pipeline.best_estimator_
                 .named_steps['vectorizer']
                 .fit_transform(raw_data)
        )

# put into a pandas dataframe
transformed = pd.DataFrame(data2, columns=feature_names)

这样,您也许可以完全跳过管道中的debug步骤,并检查管道外部的数据帧。