拟合XGBClassifier时出现KeyError: 'base_score'

时间:2020-06-05 07:57:11

标签: python grid-search xgbclassifier

我使用GridSearchCV在训练数据上进行拟合,找到了最佳的超参数:

model_xgb = XGBClassifier()
n_estimators = [50, 100, 150, 200]
max_depth = [2, 4, 6, 8]
param_grid = dict(max_depth=max_depth, n_estimators=n_estimators)
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)
grid_search = GridSearchCV(model_xgb, param_grid, scoring="neg_log_loss", n_jobs=-1, cv=kfold, verbose=1)
grid_result = grid_search.fit(train_X, y_train)

使用{'max_depth': 4, 'n_estimators': 50}时得到了最佳结果。因此,我用这些超参数创建了一个新模型:

model_xgb_tn = XGBClassifier(n_estimators=50,max_depth=4,objective='multi:softprob')

当我尝试用我的数据拟合该模型(model_xgb_tn.fit(train_X, y_train))时,我收到了KeyError: 'base_score'。即使我不指定任何超参数,也会出现这个KeyError,我无法理解其原因。

下面是完整的错误信息(Traceback):

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~\Anaconda3\lib\site-packages\IPython\core\formatters.py in __call__(self, obj, include, exclude)
    968 
    969             if method is not None:
--> 970                 return method(include=include, exclude=exclude)
    971             return None
    972         else:

~\Anaconda3\lib\site-packages\sklearn\base.py in _repr_mimebundle_(self, **kwargs)
    461     def _repr_mimebundle_(self, **kwargs):
    462         """Mime bundle used by jupyter kernels to display estimator"""
--> 463         output = {"text/plain": repr(self)}
    464         if get_config()["display"] == 'diagram':
    465             output["text/html"] = estimator_html_repr(self)

~\Anaconda3\lib\site-packages\sklearn\base.py in __repr__(self, N_CHAR_MAX)
    277             n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW)
    278 
--> 279         repr_ = pp.pformat(self)
    280 
    281         # Use bruteforce ellipsis when there are a lot of non-blank characters

~\Anaconda3\lib\pprint.py in pformat(self, object)
    142     def pformat(self, object):
    143         sio = _StringIO()
--> 144         self._format(object, sio, 0, 0, {}, 0)
    145         return sio.getvalue()
    146 

~\Anaconda3\lib\pprint.py in _format(self, object, stream, indent, allowance, context, level)
    159             self._readable = False
    160             return
--> 161         rep = self._repr(object, context, level)
    162         max_width = self._width - indent - allowance
    163         if len(rep) > max_width:

~\Anaconda3\lib\pprint.py in _repr(self, object, context, level)
    391     def _repr(self, object, context, level):
    392         repr, readable, recursive = self.format(object, context.copy(),
--> 393                                                 self._depth, level)
    394         if not readable:
    395             self._readable = False

~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in format(self, object, context, maxlevels, level)
    168     def format(self, object, context, maxlevels, level):
    169         return _safe_repr(object, context, maxlevels, level,
--> 170                           changed_only=self._changed_only)
    171 
    172     def _pprint_estimator(self, object, stream, indent, allowance, context,

~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in _safe_repr(object, context, maxlevels, level, changed_only)
    412         recursive = False
    413         if changed_only:
--> 414             params = _changed_params(object)
    415         else:
    416             params = object.get_params(deep=False)

~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in _changed_params(estimator)
     96     init_params = {name: param.default for name, param in init_params.items()}
     97     for k, v in params.items():
---> 98         if (repr(v) != repr(init_params[k]) and
     99                 not (is_scalar_nan(init_params[k]) and is_scalar_nan(v))):
    100             filtered_params[k] = v

KeyError: 'base_score'

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~\Anaconda3\lib\site-packages\IPython\core\formatters.py in __call__(self, obj)
    700                 type_pprinters=self.type_printers,
    701                 deferred_pprinters=self.deferred_printers)
--> 702             printer.pretty(obj)
    703             printer.flush()
    704             return stream.getvalue()

~\Anaconda3\lib\site-packages\IPython\lib\pretty.py in pretty(self, obj)
    400                         if cls is not object \
    401                                 and callable(cls.__dict__.get('__repr__')):
--> 402                             return _repr_pprint(obj, self, cycle)
    403 
    404             return _default_pprint(obj, self, cycle)

~\Anaconda3\lib\site-packages\IPython\lib\pretty.py in _repr_pprint(obj, p, cycle)
    695     """A pprint that just redirects to the normal repr function."""
    696     # Find newlines and replace them with p.break_()
--> 697     output = repr(obj)
    698     for idx,output_line in enumerate(output.splitlines()):
    699         if idx:

~\Anaconda3\lib\site-packages\sklearn\base.py in __repr__(self, N_CHAR_MAX)
    277             n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW)
    278 
--> 279         repr_ = pp.pformat(self)
    280 
    281         # Use bruteforce ellipsis when there are a lot of non-blank characters

~\Anaconda3\lib\pprint.py in pformat(self, object)
    142     def pformat(self, object):
    143         sio = _StringIO()
--> 144         self._format(object, sio, 0, 0, {}, 0)
    145         return sio.getvalue()
    146 

~\Anaconda3\lib\pprint.py in _format(self, object, stream, indent, allowance, context, level)
    159             self._readable = False
    160             return
--> 161         rep = self._repr(object, context, level)
    162         max_width = self._width - indent - allowance
    163         if len(rep) > max_width:

~\Anaconda3\lib\pprint.py in _repr(self, object, context, level)
    391     def _repr(self, object, context, level):
    392         repr, readable, recursive = self.format(object, context.copy(),
--> 393                                                 self._depth, level)
    394         if not readable:
    395             self._readable = False

~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in format(self, object, context, maxlevels, level)
    168     def format(self, object, context, maxlevels, level):
    169         return _safe_repr(object, context, maxlevels, level,
--> 170                           changed_only=self._changed_only)
    171 
    172     def _pprint_estimator(self, object, stream, indent, allowance, context,

~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in _safe_repr(object, context, maxlevels, level, changed_only)
    412         recursive = False
    413         if changed_only:
--> 414             params = _changed_params(object)
    415         else:
    416             params = object.get_params(deep=False)

~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in _changed_params(estimator)
     96     init_params = {name: param.default for name, param in init_params.items()}
     97     for k, v in params.items():
---> 98         if (repr(v) != repr(init_params[k]) and
     99                 not (is_scalar_nan(init_params[k]) and is_scalar_nan(v))):
    100             filtered_params[k] = v

KeyError: 'base_score'

1 个答案:

答案 0 :(得分:0)

您需要提供base_score(基础得分)参数,您可以将其视为梯度提升第一次迭代时的初始权重。对于回归问题,它是目标列的平均值;对于分类问题,它是1 /(类别数)。您可以参考xgboost的官方文档,以获得有关此超参数的更多信息。