在scikit-learn中添加自建词汇?

时间:2016-04-05 06:25:46

标签: optimization scikit-learn feature-detection vocabulary sklearn-pandas

sklearn.feature_extraction.text.TfidfVectorizer中,我们可以使用模型的vocabulary参数注入我们自己的词汇表。但在这种情况下,只有我自己选择的单词用于模型。

我想在自定义词汇表中使用自动检测到的功能。

解决此问题的一种方法是创建模型并使用

获取功能
vocab=vectorizer.get_feature_names()

将我的列表添加到词汇

vocab + vocabulary

再次构建模型。

有没有办法一步完成整个过程?

1 个答案:

答案 0 :(得分:2)

我认为没有比实现你想要的更简单的方法了。您可以做的一件事是使用 CountVectorizer 的代码来创建词汇表。我浏览了源代码,方法是

_count_vocab(self, raw_documents, fixed_vocab)

使用fixed_vocab=False调用。

所以我建议您在运行TfidfVectorizer之前调整以下代码(Source)来创建词汇表。

def _count_vocab(self, raw_documents, fixed_vocab):
        """Create sparse feature matrix, and vocabulary where fixed_vocab=False
        """
        if fixed_vocab:
            vocabulary = self.vocabulary_
        else:
            # Add a new value when a new vocabulary item is seen
            vocabulary = defaultdict()
            vocabulary.default_factory = vocabulary.__len__

        analyze = self.build_analyzer()
        j_indices = _make_int_array()
        indptr = _make_int_array()
        indptr.append(0)
        for doc in raw_documents:
            for feature in analyze(doc):
                try:
                    j_indices.append(vocabulary[feature])
                except KeyError:
                    # Ignore out-of-vocabulary items for fixed_vocab=True
                    continue
            indptr.append(len(j_indices))

        if not fixed_vocab:
            # disable defaultdict behaviour
            vocabulary = dict(vocabulary)
            if not vocabulary:
                raise ValueError("empty vocabulary; perhaps the documents only"
                                 " contain stop words")

        j_indices = frombuffer_empty(j_indices, dtype=np.intc)
        indptr = np.frombuffer(indptr, dtype=np.intc)
        values = np.ones(len(j_indices))

        X = sp.csr_matrix((values, j_indices, indptr),
                          shape=(len(indptr) - 1, len(vocabulary)),
                          dtype=self.dtype)
        X.sum_duplicates()
        return vocabulary, X