MLlib:计算多个阈值的精度和召回率

时间:2016-01-05 22:53:19

标签: scala apache-spark apache-spark-mllib

在我使用它进行评分之前,我将设置逻辑回归的阈值设置为0.5。我现在想获得该值的精确度,召回率,f1分数。不幸的是,当我尝试这样做时,我看到的唯一阈值是1.0和0.0。如何获取0和1以外的阈值的指标。

例如,这里是o / p:

阈值为:1.0,精度为:0.85

阈值为:0.0,精度为:0.312641

我没有得到精度阈值0.5。这是相关的代码。

//我在这里设置Logistic回归模型的阈值。

model.setThreshold(0.5)

// Compute the score and generate an RDD with prediction and label values.  
val predictionAndLabels = data.map { 
  case LabeledPoint(label, features) => (model.predict(features), label)
}

//我现在想要计算精度和召回以及其他指标。由于我已将模型阈值设置为0.5,因此我希望PR达到该值。

val metrics = new BinaryClassificationMetrics(predictionAndLabels)
val precision = metrics.precisionByThreshold()

precision.foreach { 
  case (t, p) => {
    println(s"Threshold is: $t, Precision is: $p")

    if (t == 0.5) {
      println(s"Desired: Threshold is: $t, Precision is: $p")        
    }
}

1 个答案:

答案 0 :(得分:1)

precisionByThreshold()方法实际上尝试不同的阈值并给出相应的精度值。由于您已经对数据进行了阈值处理,因此只有0和1。

让我们说你有: 阈值和真实标签之后的[0 0 0 1 1 1] [f f f f t t]

然后使用0进行阈值处理,[t t t t t t]给出4个误报和2个正数,因此精度为2 / (2 + 4) = 1/3

现在使用1进行阈值处理,你有[f f f t t t],并且给出1个误报和2个正数,因此精度为2 /(2 + 1) = 2/3

你可以看到使用.5的阈值现在会给你[f f f t t t],与使用1的阈值相同,所以它是你正在寻找的阈值1的精度。

这有点令人困惑,因为您已经对预测进行了阈值处理。如果您没有对预测进行阈值处理,请假设您有[.3 .4 .4 .6 .8 .9](与我一直使用的[0 0 0 1 1 1]保持一致)。

然后precisionByThreshold()会给出阈值0,.3,.4,.6 .8 .9的精度值,因为这些都是给出不同结果并因此得到不同精度的阈值,并获得值对于阈值.5,您仍然会获取下一个较大阈值(.6)的值,因为再次,它将给出相同的预测,因此具有相同的精度。