XGBoost不会生成num_round参数中指定的树数

时间:2017-08-16 07:27:28

标签: apache-spark xgboost

这不是一个错误,而是一个需要理解的问题。当我从Booster对象调用getModelDump时,我没有得到“num_round”参数中的树数。我在想,如果“num_round”为100,那么XGBoost会依次生成100棵树,当我调用getModelDump时,我会看到所有这些树。我确信背后有合理的理由,或者我的知识是错误的。你能解释一下这种情况吗?

val paramMap = List(
      "eta" -> 0.1, "max_depth" -> 7, "objective" -> "binary:logistic", "num_round" ->100,
      "eval_metric" -> "auc", "nworkers" -> 8).toMap
    val xgboostEstimator = new XGBoostEstimator(paramMap)
//TrainModel is another set of standard Spark features like StringIndexer, OnehotEncoding and VectorAssembler
    val pipelineXGBoost = new Pipeline().setStages(Array(trainModel, xgboostEstimator))
    val cvModel = pipelineXGBoost.fit(train)
//Below call generates only 2 tree instead of 100 as num_round is 100!!!
    println(cvModel.stages(1).asInstanceOf[XGBoostClassificationModel].booster.getModelDump()(0))

Github链接到问题https://github.com/dmlc/xgboost/issues/2610

使用scala 2.11

的版本如下
  "ml.dmlc" % "xgboost4j" % "0.7",
  "ml.dmlc" % "xgboost4j-spark" % "0.7",
  "org.apache.spark" %% "spark-core" % "2.2.0",
  "org.apache.spark" %% "spark-sql" % "2.2.0",
  "org.apache.spark" %% "spark-graphx" % "2.2.0",
  "org.apache.spark" %% "spark-mllib" % "2.2.0",

1 个答案:

答案 0 :(得分:0)

我没有从getModelDump的结果中得到(0 .. num_round)。每个索引都对应另一棵树。

在链接https://github.com/dmlc/xgboost/issues/2610

中回答