如何使用Spark 2.1.1获得scala 2.10中RF的功能重要性?

时间:2017-09-27 20:25:08

标签: scala apache-spark machine-learning

我试图从Spark MLib中的随机森林回归器中获取特征重要性。问题是我使用Pipeline对象进行训练,而我不知道如何将此类对象投射到RandomForestRegressorModel以获取featureImportance

我的代码中有趣的部分是以下

val rf = new RandomForestRegressor().
        setLabelCol( "label" ).
        setFeaturesCol( "features" ).
        setNumTrees( numTrees ).
        setFeatureSubsetStrategy( featureSubsetStrategy ).
        setImpurity( impurity ).
        setMaxDepth( maxDepth ).
        setMaxBins( maxBins ).
        setMaxMemoryInMB( maxMemoryInMB )
val pipeline = new Pipeline().setStages(Array(rf))
var model = pipeline.fit( trainingDataCached )
// GET FEATURE IMPORTANCE
val featImp = model.featureImportance

我错过了什么?

谢谢。

修改

这可能是正确的答案吗?

val featImp = model
              .asInstanceOf[RandomForestRegressionModel]
              .featureImportances

2 个答案:

答案 0 :(得分:0)

  

这可能是正确的答案吗?

几乎。

model
  .stages
  .head
  .asInstanceOf[RandomForestRegressionModel]
  .featureImportances

答案 1 :(得分:0)

val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] println("学习到的分类森林模型:\n" + rfModel.toDebugString)

https://spark.apache.org/docs/1.5.2/ml-ensembles.html