PysPark随机森林分类模型-提取树权重

时间:2019-10-11 11:19:10

标签: pyspark random-forest

在Spark RandomForestClassificationModel 中,每个类别的总预测概率是通过对每棵树的概率求和然后重新缩放来计算的。

在重新缩放之前,每一类的总和是在数据框中的 RawPrediction 向量中读取的(以及随后在 probability 向量中重新缩放为概率的内容)由模型转换。

但是,通过用训练集中正确分配的观察值的数量对每个类别加权并归一化,可以计算出每棵树中的概率。

使用pySpark,是否有可能访问这些权重,以便人们可以仅使用关于特定观测值的叶节点的信息来为每个观测值重构这些概率(知道权重和叶节点应该足以计算概率) )?谢谢!

有关权重如何工作的解释,请参见this

以下是来自rawPredictions矢量的计算实现的SCALA代码。似乎tree.rootNode.predictImpl(features).impurityStats.stats是我想要访问的。

 override protected def predictRaw(features: Vector): Vector = {
    // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
    // Classifies using majority votes.
    // Ignore the tree weights since all are 1.0 for now.
    val votes = Array.fill[Double](numClasses)(0.0)
    _trees.view.foreach { tree =>
      val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
      val total = classCounts.sum
      if (total != 0) {
        var i = 0
        while (i < numClasses) {
          votes(i) += classCounts(i) / total
          i += 1
        }
      }
    }
    Vectors.dense(votes)
  }

编辑: 如果无法从pySpark中访问杂质状态,是否可以从随机森林模型中提取单个分类树模型,以便手动重新计算杂质卫星的当量?

编辑2: 我使用here提出的解决方法设法获得了杂质统计信息 从Scala代码和我从该变通办法获得的结果来看,仅通过将根节点中的杂质统计信息用作权重,就应该能够获得正确的概率。正确吗?

0 个答案:

没有答案