Python TfidfVectorizer: is conditional re-initialization possible?

Date: 2015-12-10 16:11:53

Tags: python nlp scikit-learn

I'm trying to conditionally re-initialize an object.

Suppose I have the following initialization:

 TfidfVectorizer(sublinear_tf=True , decode_error='ignore', analyzer='word', tokenizer=nltk.data.load('tokenizers/punkt/english.pickle'))

Now I get a dict from a user who wants to add some parameters:

 d = {"stop_words":"english"}

How can I add the parameters in the dict to the already-initialized object, so that the final version of the object is equivalent to:

TfidfVectorizer(stop_words='english',
                sublinear_tf=True,
                decode_error='ignore',
                analyzer='word',
                tokenizer=nltk.data.load('tokenizers/punkt/english.pickle'))
Can I just do:

TfidfVectorizer(**d)

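Would this keep the previously initialized parameters? I want TfidfVectorizer to have some default settings, and then let the user choose the rest.

Is that how it works?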

1 Answer:

Answer 0 (score: 1)

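Yes, via set_params(): called on the existing object, it updates only the parameters you pass and keeps all the others, as this small experiment using get_params() shows. (By contrast, TfidfVectorizer(**d) builds a brand-new vectorizer in which every parameter not in d reverts to its default.)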

from sklearn.feature_extraction.text import TfidfVectorizer

t = TfidfVectorizer()

t.get_params()
Out[23]: 
{'analyzer': u'word',
 'binary': False,
 'charset': None,
 'charset_error': None,
 'decode_error': u'strict',
 'dtype': numpy.int64,
 'encoding': u'utf-8',
 'input': u'content',
 'lowercase': True,
 'max_df': 1.0,
 'max_features': None,
 'min_df': 1,
 'ngram_range': (1, 1),
 'norm': u'l2',
 'preprocessor': None,
 'smooth_idf': True,
 'stop_words': None,
 'strip_accents': None,
 'sublinear_tf': False,
 'token_pattern': u'(?u)\\b\\w\\w+\\b',
 'tokenizer': None,
 'use_idf': True,
 'vocabulary': None}

t.set_params(binary=True)
Out[24]: 
TfidfVectorizer(analyzer=u'word', binary=True, charset=None,
        charset_error=None, decode_error=u'strict',
        dtype=<type 'numpy.int64'>, encoding=u'utf-8', input=u'content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(1, 1), norm=u'l2', preprocessor=None, smooth_idf=True,
        stop_words=None, strip_accents=None, sublinear_tf=False,
        token_pattern=u'(?u)\\b\\w\\w+\\b', tokenizer=None, use_idf=True,
        vocabulary=None)

t.set_params(smooth_idf=False)
Out[25]: 
TfidfVectorizer(analyzer=u'word', binary=True, charset=None,
        charset_error=None, decode_error=u'strict',
        dtype=<type 'numpy.int64'>, encoding=u'utf-8', input=u'content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(1, 1), norm=u'l2', preprocessor=None,
        smooth_idf=False, stop_words=None, strip_accents=None,
        sublinear_tf=False, token_pattern=u'(?u)\\b\\w\\w+\\b',
        tokenizer=None, use_idf=True, vocabulary=None)

d = {"stop_words":"english"}

t.set_params(**d)
Out[27]: 
TfidfVectorizer(analyzer=u'word', binary=True, charset=None,
        charset_error=None, decode_error=u'strict',
        dtype=<type 'numpy.int64'>, encoding=u'utf-8', input=u'content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(1, 1), norm=u'l2', preprocessor=None,
        smooth_idf=False, stop_words='english', strip_accents=None,
        sublinear_tf=False, token_pattern=u'(?u)\\b\\w\\w+\\b',
        tokenizer=None, use_idf=True, vocabulary=None)
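The same comparison can be reproduced as a self-contained script. This is a sketch assuming a recent scikit-learn install: the NLTK tokenizer from the question is left out so it runs without NLTK data, and it additionally uses sklearn.base.clone (not part of the original answer) so that the object holding the application defaults is not mutated when a user's overrides are applied.

```python
from sklearn.base import clone
from sklearn.feature_extraction.text import TfidfVectorizer

# Application defaults (the question's NLTK tokenizer is omitted here).
base = TfidfVectorizer(sublinear_tf=True, decode_error="ignore", analyzer="word")

# Parameters supplied by the user.
d = {"stop_words": "english"}

# A fresh constructor call only sets what is passed; every other parameter
# falls back to the class default, so sublinear_tf is False again.
fresh = TfidfVectorizer(**d)
print(fresh.get_params()["sublinear_tf"])  # False

# clone() copies base's (unfitted) parameters into a new object; set_params()
# then overrides only the keys in d and returns that same object.
custom = clone(base).set_params(**d)
print(custom.get_params()["sublinear_tf"], custom.get_params()["stop_words"])  # True english

# base itself is untouched, so it can be reused for the next user.
print(base.get_params()["stop_words"])  # None
```

Calling base.set_params(**d) directly would also work, but it modifies base in place, which matters if the same defaults object is shared across several users.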

Furthermore, the source shows that .set_params() loops over only the parameters you provide and leaves the rest untouched:

# BaseEstimator.set_params from an older sklearn/base.py (`six` is the vendored sklearn.externals.six)
def set_params(self, **params):
    """Set the parameters of this estimator.

    The method works on simple estimators as well as on nested objects
    (such as pipelines). The latter have parameters of the form
    ``<component>__<parameter>`` so that it's possible to update each
    component of a nested object.

    Returns
    -------
    self
    """
    if not params:
        # Simple optimisation to gain speed (inspect is slow)
        return self
    valid_params = self.get_params(deep=True)
    for key, value in six.iteritems(params):
        split = key.split('__', 1)
        if len(split) > 1:
            # nested objects case
            name, sub_name = split
            if name not in valid_params:
                raise ValueError('Invalid parameter %s for estimator %s. '
                                 'Check the list of available parameters '
                                 'with `estimator.get_params().keys()`.' %
                                 (name, self))
            sub_object = valid_params[name]
            sub_object.set_params(**{sub_name: value})
        else:
            # simple objects case
            if key not in valid_params:
                raise ValueError('Invalid parameter %s for estimator %s. '
                                 'Check the list of available parameters '
                                 'with `estimator.get_params().keys()`.' %
                                 (key, self.__class__.__name__))
            setattr(self, key, value)
    return self
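To see the mechanism in isolation, here is a minimal pure-Python sketch of the same idea. TinyEstimator is a hypothetical class, not part of scikit-learn; it assumes, as scikit-learn's convention does, that every parameter is an __init__ argument stored under an attribute of the same name, and it mimics the simple and nested (`<component>__<parameter>`) cases of the method above.

```python
import inspect


class TinyEstimator:
    """Hypothetical stand-in for BaseEstimator; parameters are the __init__ arguments."""

    def __init__(self, sublinear_tf=False, stop_words=None, inner=None):
        self.sublinear_tf = sublinear_tf
        self.stop_words = stop_words
        self.inner = inner

    def get_params(self, deep=True):
        # Parameter names come from the __init__ signature (minus "self").
        names = [p for p in inspect.signature(type(self).__init__).parameters if p != "self"]
        params = {name: getattr(self, name) for name in names}
        if deep:
            # Expose nested components' parameters as "<component>__<parameter>".
            for name, value in list(params.items()):
                if hasattr(value, "get_params"):
                    for sub_key, sub_value in value.get_params(deep=True).items():
                        params[f"{name}__{sub_key}"] = sub_value
        return params

    def set_params(self, **params):
        valid = self.get_params(deep=True)
        for key, value in params.items():
            name, sep, sub_name = key.partition("__")
            if name not in valid:
                raise ValueError(f"Invalid parameter {name!r} for {type(self).__name__}")
            if sep:
                # Nested case: delegate the rest of the key to the component.
                valid[name].set_params(**{sub_name: value})
            else:
                # Simple case: overwrite only this attribute; others stay as they were.
                setattr(self, key, value)
        return self


est = TinyEstimator(sublinear_tf=True, inner=TinyEstimator())
est.set_params(stop_words="english", inner__sublinear_tf=True)
print(est.sublinear_tf, est.stop_words, est.inner.sublinear_tf)  # True english True
```

Because set_params only calls setattr for the keys it receives, sublinear_tf=True survives the update, which is exactly the behaviour the question relies on.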