xgb.train和xgb.XGBRegressor(或xgb.XGBClassifier)有什么区别?

时间:2017-11-07 07:54:55

标签: python machine-learning scikit-learn regression xgboost

我已经知道“xgboost.XGBRegressor是XGBoost的Scikit-Learn Wrapper接口。”

但他们还有其他区别吗?

3 个答案:

答案 0 :(得分:15)

xgboost.train是通过渐变增强方法训练模型的低级API。

xgboost.XGBRegressorxgboost.XGBClassifier是包装器( Scikit-Learn-like-wrappers ,正如他们所说的那样),它准备DMatrix并传入相应的目标函数和参数。最后,fit电话简单归结为:

self._Booster = train(params, dmatrix,
                      self.n_estimators, evals=evals,
                      early_stopping_rounds=early_stopping_rounds,
                      evals_result=evals_result, obj=obj, feval=feval,
                      verbose_eval=verbose)

这意味着XGBRegressorXGBClassifier可以通过基础xgboost.train功能实现一切。另一方面显然不是这样,例如,xgboost.train API不支持XGBModel的一些有用参数。明显差异列表包括:

  • xgboost.train允许设置在每次迭代结束时应用的callbacks
  • xgboost.train允许通过xgb_model参数继续培训。
  • xgboost.train不仅允许缩小eval函数,还允许最大化。

答案 1 :(得分:6)

@Maxim,从xgboost 0.90(或更早)开始,这些差异不再存在于xgboost.XGBClassifier.fit中:

  • import { BrowserRouter as Router, Route, Switch } from 'react-router-dom'; import React, { Suspense, lazy } from 'react'; import Header from './components/Header'; import Footer from './components/Footer'; const NewsList = lazy(() => import('./pages/NewsList')); const NewsItemPage = lazy(() => import('./pages/NewsItemPage')); const App = () => ( <Router> <Header /> <Suspense fallback={<div>Loading...</div>}> <Switch> <Route exact path="/" component={NewsList}/> <Route path="/news/:id" component={NewsItemPage}/> </Switch> </Suspense> <Footer /> </Router> );
  • 允许使用callbacks参数
  • 并支持相同的内置评估指标或自定义评估功能

我发现与众不同的是xgb_model,因为它必须在适合(evals_result)之后分别进行检索,并且得到的clf.evals_result()也有所不同,因为它无法利用监视列表中的评估名称(dict)。

答案 2 :(得分:0)

我认为主要区别在于训练/预测速度。

作为进一步参考,我将称为xgboost.train-'native_implementation'和XGBClassifier.fit-'sklearn_wrapper'

我已经对数据集形状(240000,348)进行了一些基准测试

适合/训练时间: sklearn_wrapper时间= 89秒 native_implementation时间= 7秒

预测时间: sklearn_wrapper = 6秒 native_implementation = 3.5毫秒

我认为这是由于sklearn_wrapper被设计为使用pandas / numpy对象作为输入,而native_implementation需要将输入数据转换为xgboost.DMatrix对象。 / p>

另外,可以使用native_implementation优化n_estimator。