sklearn 中 GridSearchCV 的 scoring(评分器)参数问题

时间:2015-02-20 16:02:59

标签: python scikit-learn

我正在尝试对随机森林(RF)分类器执行网格搜索,评分函数使用 sklearn.metrics 模块中的 precision_score。代码如下:

from sklearn.metrics import precision_score

# Hyper-parameter grid for the random-forest grid search.
# 5 * 4 * 3 * 2 * 2 = 240 candidate combinations in total.
param_grid = dict(
    n_estimators=[51, 101, 201, 301, 501],
    max_depth=[3, 5, 10, None],        # None = grow trees until leaves are pure
    min_samples_split=[2, 5, 10],
    criterion=['gini', 'entropy'],
    bootstrap=[True, False],
)

def fit_gridCV_RFclassifier(param_grid, X=None, y=None, scoring="precision"):
    """Grid-search a RandomForestClassifier over *param_grid* using 5-fold CV.

    Parameters
    ----------
    param_grid : dict
        Mapping of RandomForestClassifier parameter names to candidate lists.
    X, y : array-like, optional
        Training data. When omitted, the module-level ``train_X`` /
        ``train_y`` are used (the original behaviour).
    scoring : str or callable, optional
        Passed straight to GridSearchCV. Must be a scorer *name* (e.g.
        ``"precision"``) or a callable with signature
        ``scorer(estimator, X, y)`` -- e.g. ``make_scorer(precision_score)``.
        A raw metric such as ``precision_score(y_true, y_pred, ...)`` must
        NOT be passed here: GridSearchCV would call it with the estimator as
        ``y_true`` and X as ``y_pred``, causing a shape-mismatch ValueError.
        NOTE(review): for multiclass targets newer scikit-learn versions
        require e.g. ``"precision_macro"`` instead of ``"precision"``.

    Returns
    -------
    GridSearchCV
        The fitted search object (refit=True, so the best configuration is
        refitted on all of X, y).
    """
    from sklearn.ensemble import RandomForestClassifier
    try:
        # scikit-learn >= 0.18
        from sklearn.model_selection import GridSearchCV
    except ImportError:
        # Older releases, such as the 0.15.2 used in the question.
        from sklearn.grid_search import GridSearchCV

    # Fall back to the globals the original function relied on implicitly.
    if X is None:
        X = train_X
    if y is None:
        y = train_y

    rf = RandomForestClassifier()
    clf = GridSearchCV(estimator=rf, param_grid=param_grid,
                       cv=5, scoring=scoring,
                       refit=True)
    clf.fit(X, y)
    return clf

# Run the search over the grid defined above and keep the fitted search object.
gridsearch_rf = fit_gridCV_RFclassifier(param_grid=param_grid)

运行该函数时,出现以下错误:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-34-6f91362a017c> in <module>()
----> 1 gridsearch_rf = fit_gridCV_RFclassifier(param_grid)

<ipython-input-33-974d026d5dc8> in fit_gridCV_RFclassifier(param_grid)
     11                        scoring=precision_score,
     12                        cv=5, refit=True)
---> 13     clf.fit(train_X, train_y)
     14     return clf

/anaconda/lib/python2.7/site-packages/sklearn/grid_search.pyc in fit(self, X, y)
    594 
    595         """
--> 596         return self._fit(X, y, ParameterGrid(self.param_grid))
    597 
    598 

/anaconda/lib/python2.7/site-packages/sklearn/grid_search.pyc in _fit(self, X, y, parameter_iterable)
    376                                     train, test, self.verbose, parameters,
    377                                     self.fit_params, return_parameters=True)
--> 378             for parameters in parameter_iterable
    379             for train, test in cv)
    380 

/anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in __call__(self, iterable)
    651             self._iterating = True
    652             for function, args, kwargs in iterable:
--> 653                 self.dispatch(function, args, kwargs)
    654 
    655             if pre_dispatch == "all" or n_jobs == 1:

/anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in dispatch(self, func, args, kwargs)
    398         """
    399         if self._pool is None:
--> 400             job = ImmediateApply(func, args, kwargs)
    401             index = len(self._jobs)
    402             if not _verbosity_filter(index, self.verbose):

/anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in __init__(self, func, args, kwargs)
    136         # Don't delay the application, to avoid keeping the input
    137         # arguments in memory
--> 138         self.results = func(*args, **kwargs)
    139 
    140     def get(self):

/anaconda/lib/python2.7/site-packages/sklearn/cross_validation.pyc in _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters)
   1238     else:
   1239         estimator.fit(X_train, y_train, **fit_params)
-> 1240     test_score = _score(estimator, X_test, y_test, scorer)
   1241     if return_train_score:
   1242         train_score = _score(estimator, X_train, y_train, scorer)

/anaconda/lib/python2.7/site-packages/sklearn/cross_validation.pyc in _score(estimator, X_test, y_test, scorer)
   1294         score = scorer(estimator, X_test)
   1295     else:
-> 1296         score = scorer(estimator, X_test, y_test)
   1297     if not isinstance(score, numbers.Number):
   1298         raise ValueError("scoring must return a number, got %s (%s) instead."

/anaconda/lib/python2.7/site-packages/sklearn/metrics/metrics.pyc in precision_score(y_true, y_pred, labels, pos_label, average, sample_weight)
   1883                                                  average=average,
   1884                                                  warn_for=('precision',),
-> 1885                                                  sample_weight=sample_weight)
   1886     return p
   1887 

/anaconda/lib/python2.7/site-packages/sklearn/metrics/metrics.pyc in precision_recall_fscore_support(y_true, y_pred, beta, labels, pos_label, average, warn_for, sample_weight)
   1667         raise ValueError("beta should be >0 in the F-beta score")
   1668 
-> 1669     y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred)
   1670 
   1671     label_order = labels  # save this for later

/anaconda/lib/python2.7/site-packages/sklearn/metrics/metrics.pyc in _check_clf_targets(y_true, y_pred)
    107     y_pred : array or indicator matrix
    108     """
--> 109     y_true, y_pred = check_arrays(y_true, y_pred, allow_lists=True)
    110     type_true = type_of_target(y_true)
    111     type_pred = type_of_target(y_pred)

/anaconda/lib/python2.7/site-packages/sklearn/utils/validation.pyc in check_arrays(*arrays, **options)
    252         if size != n_samples:
    253             raise ValueError("Found array with dim %d. Expected %d"
--> 254                              % (size, n_samples))
    255 
    256         if not allow_lists or hasattr(array, "shape"):

ValueError: Found array with dim 317760. Expected 51

错误似乎来自评分函数。任何帮助都将不胜感激,谢谢。

我的scikit-learn版本:0.15.2

1 个答案:

答案 0 :(得分:4)

`scoring` 参数的要求如下(见 GridSearchCV 文档):

scoring:字符串、可调用对象或 None,可选,默认值:None

A string (see model evaluation documentation) or a scorer callable object / function with signature scorer(estimator, X, y).

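`precision_score` 函数的签名是 precision_score(y_true, y_pred, ...),与此不同。GridSearchCV 会以 scorer(estimator, X_test, y_test) 的方式调用它。于是估计器本身被当作 y_true,其长度等于 n_estimators,即报错中的 51;X_test 被当作 y_pred,从而引发上面的形状不匹配错误。最简单的做法是直接传入字符串,因为 "precision" 是内置评分指标之一(见文档)。如果需要自定义参数,也可以用 sklearn.metrics.make_scorer(precision_score, ...) 包装后再传入: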

# Pass the built-in scorer name rather than the raw metric function:
# GridSearchCV calls scorer(estimator, X, y), which does not match
# precision_score(y_true, y_pred, ...).
clf = GridSearchCV(
    rf,
    param_grid,
    scoring="precision",
    cv=5,
    refit=True,
)