我遇到了多类分类(38个类)的问题,并在Spark ML中实现了管道以解决该问题。这就是我生成模型的方式。
val nb = new NaiveBayes()
.setLabelCol("id")
.setFeaturesCol("features")
.setThresholds(Seq(1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25).toArray)
val pipeline = new Pipeline()
.setStages(Array(stages, assembler, nb))
val model = pipeline.fit(trainingSet)
我希望我的模型只有在其置信度(概率)大于0.8%时才能够预测一个类。
我在Spark文档中进行了大量搜索,以更好地了解阈值参数的含义,但我发现的唯一相关信息就是这一点:
在多类别分类中调整阈值的可能性 预测每个班级。数组的长度必须等于 值大于0的类,但最多一个值可以为0。 预测具有最大值p / t的类别,其中p是原始 该类别的概率,t是该类别的阈值。
这就是为什么我的阈值为1.25的原因。
问题在于,无论我为阈值插入的值是多少,都表明它根本不影响我的预测。
您知道是否有可能仅预测置信度(概率)大于特定阈值的类?
另一种方法是只选择概率大于该阈值的预测,但是我希望可以使用该框架来实现。
谢谢。