GridSearchCV是否存储所有参数组合的所有分数?

时间:2015-12-14 18:59:23

标签: python machine-learning scikit-learn artificial-intelligence

GridSearchCV使用“评分”来选择最佳估算器。训练GridSearchCV后,我希望看到每个组合的得分。 GridSearchCV是否存储每个参数组合的所有分数?如果它如何获得分数?感谢。

这是我在另一篇文章中使用的示例代码。

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.grid_search import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.naive_bayes import MultinomialNB

X_train = ['qwe rtyuiop', 'asd fghj kl', 'zx cv bnm', 'qw erty ui op', 'as df ghj kl', 'zxc vb nm', 'qwe rt yu iop', 'asdfg hj kl', 'zx cvb nm',
          'qwe rt yui op', 'asd fghj kl', 'zx cvb nm', 'qwer tyui op', 'asd fg hjk l', 'zx cv b nm', 'qw ert yu iop', 'as df gh jkl', 'zx cvb nm',
           'qwe rty uiop', 'asd fghj kl', 'zx cvbnm', 'qw erty ui op', 'as df ghj kl', 'zxc vb nm', 'qwe rtyu iop', 'as dfg hj kl', 'zx cvb nm',
          'qwe rt yui op', 'asd fg hj kl', 'zx cvb nm', 'qwer tyuiop', 'asd fghjk l', 'zx cv b nm', 'qw ert yu iop', 'as df gh jkl', 'zx cvb nm']    

y_train = ['1', '2', '3', '1', '1', '3', '1', '2', '3',
          '1', '2', '3', '1', '4', '1', '2', '2', '4', 
          '1', '2', '3', '1', '1', '3', '1', '2', '3',
          '1', '2', '3', '1', '4', '1', '2', '2', '4']    


parameters = {  
                'clf__alpha': (1e-1, 1e-2),
                 'vect__ngram_range': [(1,2),(1,3)],
                 'vect__max_df': (0.9, 0.98)
            }

text_clf_Pipline_MultinomialNB = Pipeline([('vect', CountVectorizer()),
                                           ('tfidf', TfidfTransformer()),
                                           ('clf', MultinomialNB()),                     
                                          ])
gs_clf = GridSearchCV(text_clf_Pipline_MultinomialNB, parameters, n_jobs=-1)   

gs_classifier = gs_clf.fit(X_train, y_train)

2 个答案:

答案 0 :(得分:16)

是的,完全按照docs

中的说明
  

grid_scores_:已命名的元组列表

     

包含所有参数的分数   param_grid中的组合。每个条目对应一个参数   设置。每个命名元组都具有以下属性:

     
      
  • parameters,参数设置的词典
  •   
  • mean_validation_score,交叉验证折叠的平均分数
  •   
  • cv_validation_scores,每个折叠的得分列表
  •   

答案 1 :(得分:1)

allscores=model.cv_results_['mean_test_score']
print(allscores)