Spark MLlib仅在阈值大于值时预测

时间:2018-10-30 09:40:20

标签: apache-spark machine-learning apache-spark-mllib

我遇到了多类分类(38个类)的问题,并在Spark ML中实现了管道以解决该问题。这就是我生成模型的方式。

val nb = new NaiveBayes()
  .setLabelCol("id")
  .setFeaturesCol("features")
  .setThresholds(Seq(1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25).toArray)

val pipeline = new Pipeline()
  .setStages(Array(stages, assembler, nb))

val model = pipeline.fit(trainingSet)

我希望我的模型只有在其置信度(概率)大于0.8%时才能够预测一个类。

我在Spark文档中进行了大量搜索,以更好地了解阈值参数的含义,但我发现的唯一相关信息就是这一点:

  

在多类别分类中调整阈值的可能性   预测每个班级。数组的长度必须等于   值大于0的类,但最多一个值可以为0。   预测具有最大值p / t的类别,其中p是原始   该类别的概率,t是该类别的阈值。

这就是为什么我的阈值为1.25的原因。

问题在于,无论我为阈值插入的值是多少,都表明它根本不影响我的预测。

您知道是否有可能仅预测置信度(概率)大于特定阈值的类?

另一种方法是只选择概率大于该阈值的预测,但是我希望可以使用该框架来实现。

谢谢。

0 个答案:

没有答案