Scikit-Learn: grid search with a custom CountVectorizer tokenizer

Time: 2014-06-01 14:14:02

Tags: python multithreading scikit-learn nltk

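I am currently learning more about scikit-learn and NLTK, and I am building a text classifier.

I am not a Python expert, but I am learning (I have a background in various other programming languages).

Right now I have this code, which runs a grid search over my classification pipeline to find the best parameters. I have left out some of the parameters to illustrate the problem more clearly.

In the following example, I optimize the "C" parameter of the LinearSVC classifier.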

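from pprint import pprint
from time import time

from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
# GridSearchCV lives in sklearn.model_selection in current releases
# (very old releases exposed it as sklearn.grid_search)
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

# `data` is assumed to be loaded elsewhere, e.g. a pandas DataFrame
# with Phrase and Sentiment columns

pipeline = Pipeline([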
    ('vect', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    ('clf', LinearSVC()),
])

parameters = {
    'clf__C': (5, 3, 1),
}

if __name__ == "__main__":
    # multiprocessing requires the fork to happen in a __main__ protected
    # block

    # find the best parameters for both the feature extraction and the
    # classifier
    grid_search = GridSearchCV(pipeline, parameters, n_jobs=-1, verbose=1)

    print("Performing grid search...")
    print("pipeline:", [name for name, _ in pipeline.steps])
    print("parameters:")
    pprint(parameters)
    t0 = time()
    grid_search.fit(data.Phrase, data.Sentiment)
    print("done in %0.3fs" % (time() - t0))
    print()

    print("Best score: %0.9f" % grid_search.best_score_)
    print("Best parameters set:")
    best_parameters = grid_search.best_estimator_.get_params()
    for param_name in sorted(parameters.keys()):
        print("\t%s: %r" % (param_name, best_parameters[param_name]))

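This code works fine, but if I want to use a custom tokenizer/analyzer/..., the grid search fails, probably because of multithreading.

As you know, scikit-learn expects a custom tokenizer to be a callable, for example an instance of a class that defines __call__. Here is the LemmaTokenizer from the documentation: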

from nltk.stem import WordNetLemmatizer
from nltk.tokenize import wordpunct_tokenize

class LemmaTokenizer(object):
    def __init__(self):
        self.wnl = WordNetLemmatizer()

    def __call__(self, doc):
        return [self.wnl.lemmatize(t) for t in wordpunct_tokenize(doc)]
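To make the "callable" requirement concrete, here is a minimal sketch that needs neither NLTK nor its WordNet data. RegexTokenizer and regex_tokenize are hypothetical names invented for this illustration; they only split on word characters and do no lemmatization. The point is that an instance of a class with __call__ is invoked exactly like a plain function.

```python
import re


class RegexTokenizer:
    """Hypothetical stand-in for LemmaTokenizer (no lemmatization)."""

    def __init__(self, pattern=r"\w+"):
        # Compile once so every call reuses the same pattern object.
        self.pattern = re.compile(pattern)

    def __call__(self, doc):
        # Lowercase the document and return every run of word characters.
        return self.pattern.findall(doc.lower())


def regex_tokenize(doc):
    """Plain-function equivalent of RegexTokenizer() with its default pattern."""
    return re.findall(r"\w+", doc.lower())


tokenizer = RegexTokenizer()
tokens_from_instance = tokenizer("The cats are running")
tokens_from_function = regex_tokenize("The cats are running")
```

Either object could be passed wherever a tokenizer callable is expected; the class form is mainly useful when the tokenizer has to carry state, as LemmaTokenizer does with its lemmatizer.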

Now, the pipeline on its own runs perfectly with LemmaTokenizer once I change it to:

pipeline = Pipeline([
    ('vect', CountVectorizer(tokenizer=LemmaTokenizer())),
    ('tfidf', TfidfTransformer()),
    ('clf', LinearSVC()),
])

However, if I want to run the same grid search with this custom tokenizer, it fails with the following output:

Performing grid search...
pipeline: ['vect', 'tfidf', 'clf']
parameters:
{'tfidf__norm': ('l1', 'l2')}
Fitting 3 folds for each of 2 candidates, totalling 6 fits
Exception in thread Thread-6:
Traceback (most recent call last):
  File "C:\Anaconda\lib\threading.py", line 810, in __bootstrap_inner
    self.run()
  File "C:\Anaconda\lib\threading.py", line 763, in run
    self.__target(*self.__args, **self.__kwargs)
  File "C:\Anaconda\lib\multiprocessing\pool.py", line 342, in _handle_tasks
    put(task)
PicklingError: Can't pickle <type 'instancemethod'>: attribute lookup __builtin__.instancemethod failed

My guess is that the class instance gets lost somewhere?
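The traceback ends inside multiprocessing's pool while it puts a task on the queue, which suggests that something reachable from the pipeline could not be pickled when n_jobs=-1 handed work to other processes. Nothing in the output says which attribute is responsible, so the following is only a sketch under that assumption. NaiveTokenizer, LazyTokenizer and is_picklable are hypothetical names, and a threading.Lock stands in for the problematic attribute purely because it is reliably unpicklable.

```python
import pickle
import re
import threading


def is_picklable(obj):
    """Return True if pickle.dumps(obj) succeeds, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False


class NaiveTokenizer:
    """Hypothetical tokenizer that stores an unpicklable resource eagerly."""

    def __init__(self):
        # The lock is only a stand-in for whatever cannot be pickled.
        self.resource = threading.Lock()

    def __call__(self, doc):
        return re.findall(r"\w+", doc.lower())


class LazyTokenizer:
    """Hypothetical tokenizer that builds the resource on first use."""

    def __init__(self):
        self._resource = None

    def _get_resource(self):
        # Create the resource lazily, inside whichever process needs it.
        if self._resource is None:
            self._resource = threading.Lock()
        return self._resource

    def __getstate__(self):
        # Drop the resource before pickling; it is rebuilt on demand.
        state = self.__dict__.copy()
        state["_resource"] = None
        return state

    def __call__(self, doc):
        with self._get_resource():
            return re.findall(r"\w+", doc.lower())


naive_ok = is_picklable(NaiveTokenizer())

lazy = LazyTokenizer()
lazy("warm up the resource")  # forces the resource to be created
lazy_ok = is_picklable(lazy)
restored = pickle.loads(pickle.dumps(lazy))
```

If is_picklable(LemmaTokenizer()) returned False in the failing environment, the same lazy pattern might be worth trying. Running the search with n_jobs=1 should also avoid sending the pipeline to worker processes at all, at the cost of parallelism.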

I am using the latest stable scikit-learn 1.4.

0 answers:

No answers yet.