Random Forest Analysis

Asked: 2016-12-05 21:23:03

Tags: scala apache-spark spark-dataframe apache-spark-mllib

I have a Spark (1.5.2) DataFrame and a trained RandomForestClassificationModel. I can easily fit the data and get predictions, but I want to dig deeper and analyze which edge values are the most common contributors to each binary classification outcome.

In the past I did something similar with RDDs, tracking feature usage by computing the predictions myself. In the code below I keep a list of the features used to compute the prediction. DataFrames don't seem to make this as straightforward as RDDs did.

//Recursively walk one decision tree, accumulating the feature index of every split visited
def predict(node: Node, features: Vector, path_in: Array[Int]): (Double, Double, Array[Int]) = 
{
    if (node.isLeaf) 
    {
        (node.predict.predict,node.predict.prob,path_in)
    } 
    else
    {
        //track our path through the tree
        val path = path_in :+ node.split.get.feature

        if (node.split.get.featureType == FeatureType.Continuous) 
        {
            if (features(node.split.get.feature) <= node.split.get.threshold) 
            {
                predict(node.leftNode.get, features, path)
            } 
            else 
            {
                predict(node.rightNode.get, features, path)
            }
        } 
        else 
        {
            if (node.split.get.categories.contains(features(node.split.get.feature))) 
            {
                predict(node.leftNode.get, features, path)
            }
            else 
            {
                predict(node.rightNode.get, features, path)
            }
        }
    }
}
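For reference, the path-accumulation pattern in the method above can be exercised without Spark on a toy tree. The `ToyNode` types below are hypothetical stand-ins, not Spark's `Node` API; only the recursion shape matches the code above.

```scala
// Hypothetical minimal tree types (NOT Spark's API), used only to
// illustrate accumulating the feature index at every split visited.
sealed trait ToyNode
case class ToyLeaf(prediction: Double) extends ToyNode
case class ToySplit(feature: Int, threshold: Double, left: ToyNode, right: ToyNode) extends ToyNode

// Returns the leaf prediction plus the indices of the features tested on the way down
def toyPredict(node: ToyNode, features: Array[Double], path: List[Int] = Nil): (Double, List[Int]) =
  node match {
    case ToyLeaf(p) => (p, path.reverse)
    case ToySplit(f, t, l, r) =>
      val next = if (features(f) <= t) l else r
      toyPredict(next, features, f :: path)
  }

val tree = ToySplit(0, 0.5, ToyLeaf(0.0), ToySplit(1, 2.0, ToyLeaf(1.0), ToyLeaf(0.0)))
val (pred, path) = toyPredict(tree, Array(0.9, 1.5))
// pred == 1.0, path == List(0, 1): feature 0 then feature 1 were tested
```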

I'd like to do something similar to this code, but for each feature vector return the list of all feature/edge-value pairs. Note that in my dataset all features are categorical, and the bin settings were used appropriately when building the forest.

1 Answer:

Answer 0 (score: 0)

I ended up building a custom UDF to do this:

//Base prediction method. Accepts a Random Forest model and a feature vector.
//  Returns the averaged prediction plus, per tree: the tree's prediction, its impurity,
//  and the (featureIndex, featureValue) pair used on the final edge.
def predicForest(m: RandomForestClassificationModel, point: Vector): (Double, Array[(Double, Double, (Int, Double))]) = {
    val results = m.trees.map(t=> predict(t.rootNode,point))

    (results.map(_._1).sum / results.length, results)
}

//Walk a single tree. Returns the prediction, the impurity, and the
//  (featureIndex, featureValue) pair of the last categorical split taken.
def predict(node: Node, features: Vector): (Double, Double, (Int, Double)) = {
    if (node.isInstanceOf[InternalNode]){
      //track our path through the tree
      val internalNode = node.asInstanceOf[InternalNode]
      if (internalNode.split.isInstanceOf[CategoricalSplit]) {
        val split = internalNode.split.asInstanceOf[CategoricalSplit]
        val featureValue = features(split.featureIndex)
        if (split.leftCategories.contains(featureValue)) {
          if (internalNode.leftChild.isInstanceOf[LeafNode]) {
            (internalNode.leftChild.prediction, internalNode.leftChild.impurity, (split.featureIndex, featureValue))
          } else
            predict(internalNode.leftChild, features)
        } else {
          if (internalNode.rightChild.isInstanceOf[LeafNode]) {
            (internalNode.rightChild.prediction, internalNode.rightChild.impurity, (split.featureIndex, featureValue))
          } else
            predict(internalNode.rightChild, features)
        }
      } else {
        //Continuous splits aren't handled here; return a sentinel feature/value pair
        (node.prediction, node.impurity, (-1, -1.0))
      }
    } else {
      //The node is a leaf (e.g. a single-node tree); return a sentinel feature/value pair
      (node.prediction, node.impurity, (-1, -1.0))
    }
}
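The original question asked for *all* feature/edge-value pairs along the path, not just the final edge. A sketch of that extension, using hypothetical stand-in types (NOT Spark's `Node`/`CategoricalSplit` API, though the traversal mirrors the predict method above):

```scala
// Stand-in types sketching a predict variant that accumulates every
// (featureIndex, featureValue) pair visited on the way to the leaf.
sealed trait CatNode
case class CatLeaf(prediction: Double, impurity: Double) extends CatNode
case class CatSplit(featureIndex: Int, leftCategories: Set[Double],
                    left: CatNode, right: CatNode) extends CatNode

def predictPath(node: CatNode, features: Array[Double],
                acc: List[(Int, Double)] = Nil): (Double, Double, List[(Int, Double)]) =
  node match {
    case CatLeaf(p, imp) => (p, imp, acc.reverse)
    case CatSplit(f, leftCats, l, r) =>
      val v = features(f)
      val next = if (leftCats.contains(v)) l else r
      predictPath(next, features, (f, v) :: acc)
  }

val tree = CatSplit(0, Set(0.0),
  CatLeaf(1.0, 0.1),
  CatSplit(1, Set(2.0), CatLeaf(0.0, 0.2), CatLeaf(1.0, 0.0)))
val (p, imp, path) = predictPath(tree, Array(1.0, 2.0))
// feature 0 = 1.0 is not in {0.0} → right; feature 1 = 2.0 is in {2.0} → left leaf
// p == 0.0, imp == 0.2, path == List((0, 1.0), (1, 2.0))
```

Porting this back onto Spark's real classes would mean matching on `InternalNode`/`LeafNode` and `CategoricalSplit` exactly as the answer's predict method does, with the accumulator threaded through the recursive calls.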

val rfModel = yourInstanceOfRandomForestClassificationModel

//This custom UDF executes the Random Forest Classification in a trackable way
def treeAnalyzer(m:RandomForestClassificationModel) = udf((x:Vector) =>
  predicForest(m,x))

//Execute the UDF, this will execute the Random Forest classification on each row and store the results from each tree in a new column named `prediction`
val df3 = testData.withColumn("prediction", treeAnalyzer(rfModel)(testData("indexedFeatures")))
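Once the per-tree results are collected back to the driver, finding the most common final-edge pairs is a plain-collections exercise. The data below is hypothetical, shaped like one row's `predicForest` output:

```scala
// Hypothetical per-tree results for one row:
// (treePrediction, impurity, (featureIndex, featureValue))
val treeResults: Array[(Double, Double, (Int, Double))] = Array(
  (1.0, 0.1, (3, 2.0)),
  (1.0, 0.0, (3, 2.0)),
  (0.0, 0.2, (5, 1.0))
)

// Count how often each (featureIndex, featureValue) pair decided the final edge
val edgeCounts: Map[(Int, Double), Int] =
  treeResults.groupBy(_._3).map { case (edge, hits) => edge -> hits.length }

val mostCommon = edgeCounts.maxBy(_._2)
// mostCommon == ((3, 2.0), 2): feature 3 with value 2.0 was the final edge in 2 of 3 trees
```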