如何从Spark RandomForestRegressionModel获取maxDepth

时间:2017-01-17 06:08:53

标签: apache-spark pyspark apache-spark-mllib

在Spark(2.1.0)中,我使用CrossValidator来训练RandomForestRegressorParamGridBuilder使用maxDepthnumTreesparamGrid = ParamGridBuilder() \ .addGrid(rf.maxDepth, [2, 4, 6, 8, 10]) \ .addGrid(rf.numTrees, [10, 20, 40, 50]) \ .build()

regressor = cvModel.bestModel.stages[len(cvModel.bestModel.stages) - 1]

print(regressor.getNumTrees)

训练结束后,我可以得到最好的树木数量:

regressor.trees[0].depth

但我无法弄清楚如何获得最佳的maxDepth。我已经阅读了documentation,但我看不到我错过的内容。

我注意到我可以遍历所有树并找到每个树的深度,例如

rep = []
for v in val:
    st = ''.join(d[ch] for ch in v[0])
    rep.append(st)
new_val= ' '.join(rep)

这似乎让我错过了一些东西。

2 个答案:

答案 0 :(得分:2)

不幸的是,Spark 2.3之前的PySpark RandomForestRegressionModel与Scala版本不同,它不存储上游Estimator Params,但您应该能够直接从JVM对象中检索它。用一个简单的猴子补丁:

from pyspark.ml.regression import RandomForestRegressionModel

RandomForestRegressionModel.getMaxDepth = (
    lambda self: self._java_obj.getMaxDepth()
)
你可以:

cvModel.bestModel.stages[-1].getMaxDepth()

答案 1 :(得分:1)

更简单,只需致电

    cvModel.bestModel.stages[-1]._java_obj.getMaxDepth()

正如@ user6910411所解释的那样,您将获得bestModel,调用此模型的JVM对象并使用JVM对象中的getMaxDepth()提取参数。 其他参数类似。