如何从Spark MLlib计算原始分数推断出预测的类标签

时间:2017-04-07 18:24:10

标签: scala apache-spark apache-spark-mllib apache-spark-ml

阅读下面的Spark文档

https://spark.apache.org/docs/latest/mllib-optimization.html

以下二进制分类预测的示例代码段:

    val model = new LogisticRegressionModel(
    Vectors.dense(weightsWithIntercept.toArray.slice(0,weightsWithIntercept.size - 1)),
    weightsWithIntercept(weightsWithIntercept.size - 1))

    // Clear the default threshold.
    model.clearThreshold()

   // Compute raw scores on the test set.
   val scoreAndLabels = test.map { point =>
   val score = model.predict(point.features)
   (score, point.label)

如您所见,model.prediction(point.features)返回原始分数,即超平面分离距离的边距。

我的问题是:

(1)如何根据上述计算的原始分数知道预测类标签是0还是1?

(2)如何从上面计算的原始分数中推断出这个二元分类案例中的预测类标签(0或1)?

1 个答案:

答案 0 :(得分:3)

默认情况下,阈值为0.5,因此在使用BinaryClassificationMetrics时,如果分数为< {0},则会给出类标签0。如果它更高,则为0.5和1。所以你也可以这样做,从分数中推断出这个类。