我使用随机森林将文本分类为某些类别。使用测试数据时,我的精度为0.98。但是,使用另一组数据时,总体精度降低到0.7。我认为,大多数行仍具有较高的准确性。
因此,现在我只想高度显示预测类别。 random-forrest给了我一列“概率”,它是概率的数组。如何获得所选预测的实际概率?
val randomForrest = new RandomForestClassifier()
.setLabelCol(labelIndexer.getOutputCol)
.setFeaturesCol(vectorAssembler.getOutputCol)
.setProbabilityCol("probability")
.setSeed(123)
.setPredictionCol("prediction")
答案 0 :(得分:0)
我最终想出了以下udf以获得最佳预测及其概率。 如果有更方便的方法,请发表评论。
def getBestPrediction = udf((
rawPrediction: org.apache.spark.ml.linalg.Vector, probability: org.apache.spark.ml.linalg.Vector) => {
val bestPrediction = probability.argmax
val bestProbability = probability(bestPrediction)
(bestPrediction, bestProbability)
})