spark MLLIB中的节点统计

时间:2019-06-27 14:25:14

标签: scala apache-spark apache-spark-mllib

我正在遵循文档来训练决策树回归器或我的数据(https://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier) 最终的输出看起来像这样:

Learned regression tree model:
 DecisionTreeRegressionModel (uid=dtr_ba1638819fb1) of depth 5 with 63 nodes
  If (feature 41 <= 0.0)
   If (feature 35 <= 5.0)
    If (feature 42 <= 60.0)
     If (feature 0 <= 3740051.0)
      If (feature 23 <= 2.0)
       Predict: 1.2777917018136313E-4
      Else (feature 23 > 2.0)
       Predict: 3.5522811772381764E-4
     Else (feature 0 > 3740051.0)
      If (feature 32 <= 1.0)
       Predict: 1.0701321366121918E-4
      Else (feature 32 > 1.0)
       Predict: 1.2083112677997485E-4
    Else (feature 42 > 60.0)
etc.

这一切都很好,但我想在每个节点中都有一些统计信息(至少是示例数量)。 说出对应于以下节点的示例数:

(feature 41 <= 0.0) and (feature 35 <= 5.0))

如scikit-learn中所述。我在API中找不到任何可以使我接近这一点的东西。帮助非常感谢!

谢谢

0 个答案:

没有答案