如何仅获得概率大于x的预测

时间:2020-06-03 16:00:23

标签: random-forest apache-spark-mllib

我使用随机森林将文本分类为某些类别。使用测试数据时,我的精度为0.98。但是,使用另一组数据时,总体精度降低到0.7。我认为,大多数行仍具有较高的准确性。

因此,现在我只想高度显示预测类别。 random-forrest给了我一列“概率”,它是概率的数组。如何获得所选预测的实际概率?

val randomForrest = new RandomForestClassifier()
      .setLabelCol(labelIndexer.getOutputCol)
      .setFeaturesCol(vectorAssembler.getOutputCol)
      .setProbabilityCol("probability")
      .setSeed(123)
      .setPredictionCol("prediction")

1 个答案:

答案 0 :(得分:0)

我最终想出了以下udf以获得最佳预测及其概率。 如果有更方便的方法,请发表评论。

def getBestPrediction = udf((
  rawPrediction: org.apache.spark.ml.linalg.Vector, probability: org.apache.spark.ml.linalg.Vector) => {
  val bestPrediction = probability.argmax
  val bestProbability = probability(bestPrediction)     
  (bestPrediction, bestProbability)
})