Error when using fit_params in GridSearchCV when some parameters are singletons

Date: 2019-12-11 17:38:20

Tags: python scikit-learn xgboost

I want to use XGBoost's early stopping with GridSearchCV, but after upgrading some modules, code that used to work now raises an interesting error. Current versions:

  • xgboost: 0.90
  • scikit_learn: 0.22
  • python: 3.6.5

Reproducible example:

from sklearn.datasets import load_iris
from xgboost.sklearn import XGBRegressor
from sklearn.model_selection import GridSearchCV, train_test_split

iris = load_iris()

x_train, x_validate, y_train, y_validate = train_test_split(
    iris['data'], 
    iris['target'], 
    random_state=7, 
    train_size=0.75
)

model = XGBRegressor()

grid_params = {
    'max_depth': [1, 2, 3, 4, 5],
    'colsample_bytree': [0.6, 0.7, 0.8, 0.9, 1.0],
    'subsample': [0.7, 0.8, 0.9, 1.0],    
}

grid = GridSearchCV(
    model,
    cv=5,
    n_jobs=10,
    param_grid = grid_params,
    verbose=3,
    refit=True
)

fit_params = {
    'verbose': False,
    'early_stopping_rounds': 10,        
    'eval_set': [(
        x_validate,
        y_validate
    )],
}

grid.fit(
    x_train,
    y_train,
    **fit_params
)

The resulting error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-1-c17691f9afb9> in <module>
     41     x_train,
     42     y_train,
---> 43     **fit_params
     44 )

/opt/conda/lib/python3.6/site-packages/sklearn/model_selection/_search.py in fit(self, X, y, groups, **fit_params)
    650         X, y, groups = indexable(X, y, groups)
    651         # make sure fit_params are sliceable
--> 652         fit_params_values = indexable(*fit_params.values())
    653         fit_params = dict(zip(fit_params.keys(), fit_params_values))
    654 

/opt/conda/lib/python3.6/site-packages/sklearn/utils/validation.py in indexable(*iterables)
    235         else:
    236             result.append(np.array(X))
--> 237     check_consistent_length(*result)
    238     return result
    239 

/opt/conda/lib/python3.6/site-packages/sklearn/utils/validation.py in check_consistent_length(*arrays)
    206     """
    207 
--> 208     lengths = [_num_samples(X) for X in arrays if X is not None]
    209     uniques = np.unique(lengths)
    210     if len(uniques) > 1:

/opt/conda/lib/python3.6/site-packages/sklearn/utils/validation.py in <listcomp>(.0)
    206     """
    207 
--> 208     lengths = [_num_samples(X) for X in arrays if X is not None]
    209     uniques = np.unique(lengths)
    210     if len(uniques) > 1:

/opt/conda/lib/python3.6/site-packages/sklearn/utils/validation.py in _num_samples(x)
    150         if len(x.shape) == 0:
    151             raise TypeError("Singleton array %r cannot be considered"
--> 152                             " a valid collection." % x)
    153         # Check that shape is returning an integer or default to len
    154         # Dask dataframes may not return numeric shape[0] value

TypeError: Singleton array array(False) cannot be considered a valid collection.
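
The traceback shows `indexable` wrapping a fit parameter with `np.array(X)` before `check_consistent_length` asks for its number of samples. A minimal sketch of why a scalar such as `False` fails that check, using only NumPy; the helper names `describe` and `has_length` are made up for illustration and are not part of scikit-learn:

```python
import numpy as np


def describe(value):
    """Return (ndim, shape) of the array NumPy builds from value."""
    arr = np.array(value)
    return arr.ndim, arr.shape


def has_length(value):
    """Report whether the array built from value supports len()."""
    try:
        len(np.array(value))
        return True
    except TypeError:
        # 0-dimensional ("singleton") arrays are unsized
        return False


scalar_info = describe(False)     # a bare scalar, as in 'verbose': False
wrapped_info = describe([False])  # a one-element list, as in 'verbose': [False]

print(scalar_info, has_length(False))
print(wrapped_info, has_length([False]))
```

A bare scalar becomes a 0-dimensional array with an empty shape, which has no sample count to compare, whereas a one-element list becomes a 1-dimensional array of length 1.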

If I tweak the code slightly so that fit_params looks like this, it runs without errors, but early stopping is no longer applied:

fit_params = {
    'eval_set': [(
        x_validate,
        y_validate
    )],
}

The fit_params values are valid for the model; I checked this by running:

model.fit(
    x_train,
    y_train,
    **fit_params
)

This works as expected, but it fits only a single instance of the model rather than running a grid search.

What is going on with my grid search, and is there a workaround?

Update:

If I change fit_params to look like this and set refit = False, the grid search runs without errors, but I am not sure whether early stopping is applied correctly:

fit_params = {
    'verbose': [False],
    'early_stopping_rounds': [10],        
    'eval_set': [(
        x_validate,
        y_validate
    )],
}

1 Answer:

Answer 0 (score: 1)

This appears to be a known issue in scikit-learn 0.22:

https://github.com/scikit-learn/scikit-learn/issues/15805

For now, there is a workaround here:

https://github.com/scikit-learn/scikit-learn/issues/15805#issuecomment-562927893

cia05rf commented 3 days ago
Thanks, i think this is best sorted on the lightGBM git so i'll raise an issue there.

For anyone who does come across this i have found a work around by making the below changes to sklearn/model_selection/_search.py -> fit.

Currently at line 651

# make sure fit_params are sliceable
- fit_params_values = indexable(*fit_params.values())
- fit_params = dict(zip(fit_params.keys(), fit_params_values))
+# fit_params_values = indexable(*fit_params.values())
+# fit_params = dict(zip(fit_params.keys(), fit_params_values))
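
If editing the installed scikit-learn source is not an option, another route is to enumerate the grid by hand so the fit parameters reach `fit` without passing through `indexable`. The following is only a sketch under stated assumptions: `iter_param_grid`, `manual_search`, `make_model` and `fit_and_score` are hypothetical names, it performs no cross-validation, and the scoring callable is a stand-in that you would replace with real fitting and evaluation code.

```python
from itertools import product


def iter_param_grid(grid):
    """Yield one dict per combination of the values in grid."""
    keys = sorted(grid)
    for values in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


def manual_search(make_model, grid, fit_and_score, fit_params):
    """Try every combination; fit_params are handed over untouched."""
    best_score, best_params = None, None
    for params in iter_param_grid(grid):
        model = make_model(**params)
        # hypothetical callable: fits the model with fit_params, returns a score
        score = fit_and_score(model, fit_params)
        if best_score is None or score > best_score:
            best_score, best_params = score, params
    return best_params, best_score


grid_params = {
    'max_depth': [1, 2, 3, 4, 5],
    'colsample_bytree': [0.6, 0.7, 0.8, 0.9, 1.0],
    'subsample': [0.7, 0.8, 0.9, 1.0],
}

# number of candidate settings the grid from the question expands to
n_candidates = sum(1 for _ in iter_param_grid(grid_params))
```

Here `make_model` could be the `XGBRegressor` class itself and `fit_and_score` a function that calls `model.fit(x_train, y_train, **fit_params)` and scores on held-out data. scikit-learn's `ParameterGrid` provides the same enumeration if you prefer not to write it yourself.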