当使用来自sklearn的cross_validate的多个计分器时,每个计分器的输出都相同

时间:2019-12-25 22:31:09

标签: scikit-learn cross-validation text-classification

我运行此代码来测试20Newsgroups数据集上不同分类器的文本分类。 (这里以Complement Naive Bayes为例)。我想用cross_validate函数向我打印精度,精度,召回率和f1度量值,但是我获得的精度,精度,召回率和f1度量值完全相同。这一定是错误的,还是我误解了评分的行为?

from sklearn import datasets
from sklearn.feature_extraction.text import TfidfVectorizer

from sklearn.metrics import make_scorer, accuracy_score, precision_score, recall_score, f1_score

from sklearn.naive_bayes import ComplementNB

from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_validate

#Defining the Dictionary of wanted measures
evaluation_scores = {'accuracy' : make_scorer(accuracy_score), 
                     'precision' : make_scorer(precision_score, average='micro'),
                     'recall' : make_scorer(recall_score, average='micro'),
                     'f1_score' : make_scorer(f1_score, average='micro')}


#Loading the Dataset 
categories= ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']
twenty_all = datasets.fetch_20newsgroups(subset='all', categories=categories,
                                         remove=('headers', 'footers', 'quotes'),
                                           shuffle=True)
#Preparing the features
tfidf_vect = TfidfVectorizer(stop_words='english')
x_full_tfidf = tfidf_vect.fit_transform(twenty_all.data)

X_train_regular, X_test_regular, y_train_regular, y_test_regular = train_test_split(x_full_tfidf, twenty_all.target)

classifier = ComplementNB()

#Cross-Validation    
scores = cross_validate(classifier, X_train_regular, y_train_regular, cv=5, scoring=evaluation_scores, return_train_score=False, n_jobs=-1)

#Printing measures
for key, values in scores.items():
            print(key,'\tmean ', values.mean(), '\t+/-', values.std(), ' std')

输出:

fit_time        mean  0.02199974060058594       +/- 0.015006788218343144  std
score_time      mean  0.007401227951049805      +/- 0.0007992988464814343  std
test_accuracy   mean  0.8559753091291809        +/- 0.021032351448968582  std
test_precision  mean  0.8559753091291809        +/- 0.021032351448968582  std
test_recall     mean  0.8559753091291809        +/- 0.021032351448968582  std
test_f1_score   mean  0.8559753091291809        +/- 0.02103235144896859  std

您能帮我发现错误吗?还是这样,我错过了其他事情吗?我怀疑得分手的均值=“微”属性。

谢谢。

0 个答案:

没有答案