How to use GridSearchCV with a LightGBM classifier for a multiclass problem? (Python)

Posted: 2019-11-06 11:00:47

Tags: python machine-learning lightgbm

I am doing the following:

from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, cross_val_score, train_test_split
import lightgbm as lgb

# Candidate learning rates for the grid search
param_test = {
    'learning_rate': [0.01, 0.02, 0.03, 0.04, 0.05, 0.08, 0.1, 0.2, 0.3, 0.4]
}

clf = lgb.LGBMClassifier(
    boosting_type='gbdt',
    num_leaves=31,
    max_depth=-1,
    n_estimators=100,
    subsample_for_bin=200000,
    objective='multiclass',
    class_weight='balanced',  # must be a string; a bare `balanced` is an undefined name
    min_split_gain=0.0,
    min_child_weight=0.001,
    min_child_samples=20,
    subsample=1.0,
    subsample_freq=0,
    colsample_bytree=1.0,
    reg_alpha=0.0,
    reg_lambda=0.0,
    random_state=None,
    n_jobs=-1,
    silent=True,  # newer LightGBM releases dropped `silent`; use verbose=-1 there
    importance_type='split',
)

gs = GridSearchCV(
    estimator=clf,
    param_grid=param_test,
    scoring='roc_auc',
    cv=3,
)

# X_train and y_train_lbl are defined elsewhere (not shown)
gs.fit(X_train, y_train_lbl["target_encoded"].values)

I get the following error:

/home/cdsw/.local/lib/python3.6/site-packages/sklearn/model_selection/_validation.py in _score(estimator, X_test, y_test, scorer, is_multimetric)
    597     """
    598     if is_multimetric:
--> 599         return _multimetric_score(estimator, X_test, y_test, scorer)
    600     else:
    601         if y_test is None:

/home/cdsw/.local/lib/python3.6/site-packages/sklearn/model_selection/_validation.py in _multimetric_score(estimator, X_test, y_test, scorers)
    627             score = scorer(estimator, X_test)
    628         else:
--> 629             score = scorer(estimator, X_test, y_test)
    630 
    631         if hasattr(score, 'item'):

/home/cdsw/.local/lib/python3.6/site-packages/sklearn/metrics/scorer.py in __call__(self, clf, X, y, sample_weight)
    173         y_type = type_of_target(y)
    174         if y_type not in ("binary", "multilabel-indicator"):
--> 175             raise ValueError("{0} format is not supported".format(y_type))
    176 
    177         if is_regressor(clf):

**ValueError: multiclass format is not supported**
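The last frame of the traceback shows where the error comes from: the `roc_auc` scorer calls `type_of_target(y)` and accepts only `"binary"` or `"multilabel-indicator"` targets, while an integer-encoded label with more than two classes is classified as `"multiclass"`. The sketch below, assuming scikit-learn 0.22 or later (needed for the `multi_class` argument of `roc_auc_score`), checks the target types on a small made-up label vector and then computes a one-vs-rest multiclass AUC from hypothetical class probabilities.

```python
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize
from sklearn.utils.multiclass import type_of_target

# Made-up integer-encoded labels with three classes
y = np.array([0, 1, 2, 2, 1, 0])

# Integer labels with >2 classes are "multiclass", which the 'roc_auc' scorer rejects
print(type_of_target(y))

# One-hot (indicator) encoding of the same labels is "multilabel-indicator"
print(type_of_target(label_binarize(y, classes=[0, 1, 2])))

# Hypothetical predicted probabilities: one column per class, each row sums to 1
proba = np.array([
    [0.7, 0.2, 0.1],
    [0.2, 0.6, 0.2],
    [0.1, 0.2, 0.7],
    [0.2, 0.3, 0.5],
    [0.3, 0.5, 0.2],
    [0.6, 0.3, 0.1],
])

# One-vs-rest AUC per class, macro-averaged (the default average)
print(roc_auc_score(y, proba, multi_class='ovr'))
```

In this toy data every class's positives score above its negatives, so each one-vs-rest AUC, and therefore the macro average, is 1.0.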

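What puzzles me is the ValueError saying the multiclass format is not supported. Am I missing something fundamental here? I am using AUC as the metric. Should it be multi_logloss instead? I tried that as well, without success.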

Can anyone help me?

0 answers