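scikit-learn: how to check whether a model (e.g. TfidfVectorizer) has already been fitted

Time: 2018-07-16 20:33:25

Tags: python numpy machine-learning scikit-learn

When extracting features from text, how can I check whether a vectorizer (e.g. TfidfVectorizer or CountVectorizer) has already been fitted on the training data?
In particular, I want the code to determine automatically whether the vectorizer has already been fitted.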

from sklearn.feature_extraction.text import TfidfVectorizer

vectorizer = TfidfVectorizer()

def vectorize_data(texts):
  # if vectorizer has not been already fit
  vectorizer.fit_transform(texts)
  # else
  vectorizer.transform(texts)

2 answers:

Answer 0 (score: 3)

You can use check_is_fitted to do this.

In the source of TfidfVectorizer.transform(), you can see how it is used:

def transform(self, raw_documents, copy=True):

    # This is what you need.
    check_is_fitted(self, '_tfidf', 'The tfidf vector is not fitted')

    X = super(TfidfVectorizer, self).transform(raw_documents)
    return self._tfidf.transform(X, copy=False)

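So you can do this:

from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

def vectorize_data(texts):
    try:
        # 'vocabulary_' is only set by fitting; in some scikit-learn releases
        # '_tfidf' already exists after __init__, so it is not a reliable marker.
        # msg is passed by keyword because newer releases make it keyword-only.
        check_is_fitted(vectorizer, 'vocabulary_', msg='The tfidf vector is not fitted')
    except NotFittedError:
        vectorizer.fit(texts)

    # In all cases the vectorizer is fitted here, so just call transform()
    return vectorizer.transform(texts)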
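
If you would rather avoid the try/except, a plain attribute test is an alternative. This is a minimal sketch, not what the answer above uses: it assumes that `vocabulary_` is absent on a freshly constructed vectorizer and appears once it has been fitted, which holds for TfidfVectorizer and CountVectorizer as far as I know. The sample sentences are made up for illustration.

```python
from sklearn.feature_extraction.text import TfidfVectorizer

vectorizer = TfidfVectorizer()


def vectorize_data(texts):
    """Fit on the first call, only transform on later calls."""
    # Assumption: vocabulary_ is created during fitting, so its absence
    # means the vectorizer has not been fitted yet.
    if not hasattr(vectorizer, "vocabulary_"):
        return vectorizer.fit_transform(texts)
    return vectorizer.transform(texts)


train = vectorize_data(["the cat sat", "the dog barked"])  # first call fits
test = vectorize_data(["the cat"])  # later call reuses the learned vocabulary
print(train.shape, test.shape)
```

Both matrices end up with the same number of columns, because the second call reuses the vocabulary learned by the first.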

Answer 1 (score: 2)

I propose two ways to check.

The first is my own code, which covers all scikit-learn models:

import inspect

def my_inspector(model):
    return 0 < len( [k for k,v in inspect.getmembers(model) if k.endswith('_') and not k.startswith('__')] )

Now let's test it with the following code:

from sklearn.feature_extraction.text import TfidfVectorizer
import inspect

vectorizer = TfidfVectorizer()

def my_inspector(model):
    return 0 < len( [k for k,v in inspect.getmembers(model) if k.endswith('_') and not k.startswith('__')] )

my_inspector(vectorizer)
# False
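
Note that inspect.getmembers also walks class-level attributes such as properties, so depending on the Python and scikit-learn versions it may report names that are not fitted state. A variant restricted to instance attributes via vars() could look like the sketch below. `has_fitted_attributes` and `ToyEstimator` are hypothetical names introduced only for illustration; the toy class stands in for a real estimator so the example runs without scikit-learn.

```python
def has_fitted_attributes(model):
    """Return True if the instance holds any attribute named like fitted state."""
    # vars() only sees attributes stored on the instance itself,
    # so class-level properties are ignored.
    return any(
        name.endswith("_") and not name.startswith("__")
        for name in vars(model)
    )


class ToyEstimator:
    """Hypothetical stand-in that follows the trailing-underscore convention."""

    def __init__(self, lowercase=True):
        self.lowercase = lowercase  # constructor parameter, no trailing underscore

    def fit(self, texts):
        tokens = {
            (tok.lower() if self.lowercase else tok)
            for text in texts
            for tok in text.split()
        }
        # Learned state gets a trailing underscore, as in scikit-learn.
        self.vocabulary_ = {tok: i for i, tok in enumerate(sorted(tokens))}
        return self


model = ToyEstimator()
before = has_fitted_attributes(model)
model.fit(["the cat sat", "the dog barked"])
after = has_fitted_attributes(model)
print(before, after)
```

The check relies purely on the naming convention, so it only works for estimators that store what they learn in attributes ending with an underscore.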

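The second method, using check_is_fitted:

from sklearn.utils.validation import check_is_fitted

# Raises NotFittedError if the vectorizer has not been fitted yet.
# msg is passed by keyword because newer scikit-learn releases make it keyword-only.
check_is_fitted(vectorizer, 'vocabulary_', msg='The tfidf vector is not fitted')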
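
Since check_is_fitted signals an unfitted model by raising an exception, it can be wrapped to get a True/False answer. The sketch below assumes a reasonably recent scikit-learn release in which the attribute names are optional, so check_is_fitted(model) can be called with the estimator alone. `is_fitted` is a hypothetical helper name, not part of scikit-learn.

```python
from sklearn.exceptions import NotFittedError
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.utils.validation import check_is_fitted


def is_fitted(model):
    """Return True if check_is_fitted accepts the model, False otherwise."""
    try:
        # Assumption: attribute names may be omitted in this scikit-learn version.
        check_is_fitted(model)
    except NotFittedError:
        return False
    return True


count_vectorizer = CountVectorizer()
unfitted = is_fitted(count_vectorizer)
count_vectorizer.fit(["the cat sat"])
fitted = is_fitted(count_vectorizer)
print(unfitted, fitted)
```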