I'm trying to tune the alpha parameter of Multinomial Naive Bayes on the 20newsgroups dataset. Here is my code so far:
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV
import numpy as np
# Divide dataset
dataset_train = fetch_20newsgroups(subset='train', shuffle=True)
dataset_test = fetch_20newsgroups(subset='test', shuffle=True)
text_clf = Pipeline([('vect', CountVectorizer()),
                     ('tfidf', TfidfTransformer(sublinear_tf=True)),
                     ('clf', MultinomialNB())])
param_grid = {'tfidf__use_idf': (True, False),
              'clf__alpha': np.linspace(0.001, 1, 100)}
grid_search = GridSearchCV(text_clf, param_grid=param_grid, scoring='precision', cv = None)
# Training
text_clf = grid_search.fit(dataset_train.data,dataset_train.target, average=None)
#prediction
predicted = text_clf.predict(dataset_test.data)
print("NB Accuracy:", 100*np.mean(predicted == dataset_test.target), '%')
print(classification_report(dataset_test.target, predicted, target_names=dataset_train.target_names))
print("Best estimator for alpha in order to get precision ", grid_search.best_estimator_)
The problem is that I get the following error:
runfile('C:/Users/omarl/Downloads/new_NB.py', wdir='C:/Users/omarl/Downloads')
Traceback (most recent call last):
  File "<ipython-input-12-d478372ef22a>", line 1, in <module>
    runfile('C:/Users/omarl/Downloads/new_NB.py', wdir='C:/Users/omarl/Downloads')
  File "C:\Users\omarl\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 705, in runfile
    execfile(filename, namespace)
  File "C:\Users\omarl\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 102, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)
  File "C:/Users/omarl/Downloads/new_NB.py", line 28, in <module>
    text_clf = grid_search.fit(dataset_train.data,dataset_train.target, average=None)
  File "C:\Users\omarl\Anaconda3\lib\site-packages\sklearn\model_selection\_search.py", line 639, in fit
    cv.split(X, y, groups)))
  File "C:\Users\omarl\Anaconda3\lib\site-packages\sklearn\externals\joblib\parallel.py", line 779, in __call__
    while self.dispatch_one_batch(iterator):
  File "C:\Users\omarl\Anaconda3\lib\site-packages\sklearn\externals\joblib\parallel.py", line 625, in dispatch_one_batch
    self._dispatch(tasks)
  File "C:\Users\omarl\Anaconda3\lib\site-packages\sklearn\externals\joblib\parallel.py", line 588, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)
  File "C:\Users\omarl\Anaconda3\lib\site-packages\sklearn\externals\joblib\_parallel_backends.py", line 111, in apply_async
    result = ImmediateResult(func)
  File "C:\Users\omarl\Anaconda3\lib\site-packages\sklearn\externals\joblib\_parallel_backends.py", line 332, in __init__
    self.results = batch()
  File "C:\Users\omarl\Anaconda3\lib\site-packages\sklearn\externals\joblib\parallel.py", line 131, in __call__
    return [func(*args, **kwargs) for func, args, kwargs in self.items]
  File "C:\Users\omarl\Anaconda3\lib\site-packages\sklearn\externals\joblib\parallel.py", line 131, in <listcomp>
    return [func(*args, **kwargs) for func, args, kwargs in self.items]
  File "C:\Users\omarl\Anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 458, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "C:\Users\omarl\Anaconda3\lib\site-packages\sklearn\pipeline.py", line 248, in fit
    Xt, fit_params = self._fit(X, y, **fit_params)
  File "C:\Users\omarl\Anaconda3\lib\site-packages\sklearn\pipeline.py", line 197, in _fit
    step, param = pname.split('__', 1)
ValueError: not enough values to unpack (expected 2, got 1)
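I don't understand why this happens, because based on the examples I've looked at, this code should work. I also searched the scikit-learn website but didn't find anything. Thanks.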
Answer 0 (score: 1)
On this line:
text_clf = grid_search.fit(dataset_train.data,dataset_train.target, average=None)
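average=None is passed through GridSearchCV.fit to Pipeline.fit as a fit parameter, which is not what you want. Pipeline expects fit parameters named <step>__<parameter> (for example clf__sample_weight) and splits each name on '__' to find the target step. The name average contains no '__', so the pname.split('__', 1) call at the bottom of the traceback returns only one value, hence "not enough values to unpack (expected 2, got 1)".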
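A minimal sketch of that routing rule, assuming a recent scikit-learn. The four-document corpus and the variable names are hypothetical stand-ins for 20newsgroups so it runs without a download. A fit parameter written as clf__sample_weight reaches the MultinomialNB step, while average=None is rejected with a ValueError (the exact message differs between scikit-learn versions; newer ones explain the stepname__parameter format instead of failing on the unpack).

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Hypothetical toy corpus standing in for 20newsgroups.
docs = ["the cat sat", "a dog barked", "the cat purred", "a dog ran"]
labels = [0, 1, 0, 1]

pipe = Pipeline([
    ("vect", CountVectorizer()),
    ("tfidf", TfidfTransformer(sublinear_tf=True)),
    ("clf", MultinomialNB()),
])

# Correct form: "<step name>__<parameter>" sends sample_weight to the "clf" step.
pipe.fit(docs, labels, clf__sample_weight=np.ones(len(docs)))

# Incorrect form: "average" has no "__", so Pipeline cannot tell which step it
# belongs to and raises ValueError before fitting anything.
try:
    pipe.fit(docs, labels, average=None)
    routing_error = None
except ValueError as exc:
    routing_error = str(exc)

print("Rejected fit parameter:", routing_error)
```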
If you remove average=None, you get this error instead:
ValueError: Target is multiclass but average='binary'. Please choose another average setting.
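This is because the 'precision' scorer computes binary precision (average='binary'), which is not defined for a multiclass target. If you change the scoring parameter to 'accuracy', the code works. If you specifically want to optimize precision, use an averaged scorer such as 'precision_macro', 'precision_micro' or 'precision_weighted'.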
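A minimal sketch of the corrected search, assuming scikit-learn 0.19 or later. The twelve-document, three-class corpus is hypothetical so the example runs offline, and the grid uses 5 alpha values instead of 100 to stay fast. The two changes relative to the question are that fit() receives no extra keyword arguments and that scoring uses a multiclass-safe scorer ('precision_macro' here; 'accuracy' would also work).

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Hypothetical three-class corpus standing in for 20newsgroups.
docs = [
    "cats purr and nap", "my cat naps all day", "kittens purr softly", "cat food and cat toys",
    "dogs bark loudly", "my dog fetches sticks", "puppies bark at night", "dog leash and dog bowl",
    "stocks rose today", "markets fell on news", "investors bought stocks", "bond markets rallied",
]
labels = [0] * 4 + [1] * 4 + [2] * 4

text_clf = Pipeline([
    ("vect", CountVectorizer()),
    ("tfidf", TfidfTransformer(sublinear_tf=True)),
    ("clf", MultinomialNB()),
])

# Grid keys use the same "<step>__<parameter>" naming as fit parameters.
param_grid = {
    "tfidf__use_idf": (True, False),
    "clf__alpha": np.linspace(0.001, 1, 5),
}

# "precision_macro" averages per-class precision, so it is defined for a
# multiclass target. Some folds may emit UndefinedMetricWarning on this tiny data.
grid_search = GridSearchCV(text_clf, param_grid=param_grid,
                           scoring="precision_macro", cv=3)
grid_search.fit(docs, labels)  # no extra keyword arguments

print("Best parameters:", grid_search.best_params_)
print("Best CV macro precision:", grid_search.best_score_)
```

With the real data, fit on dataset_train.data and dataset_train.target and evaluate grid_search on dataset_test.data as in the question. Macro averaging weights every class equally; 'precision_weighted' weights each class by its frequency instead.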