Calculating precision and recall for a specific threshold

Date: 2020-02-17 17:58:48

Tags: scala apache-spark

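I want to set the logistic regression threshold to 0.5 and get the precision, recall and F1 score at that threshold for my pipeline model. However,

    model.setThreshold(0.5)

fails with

    value setThreshold is not a member of org.apache.spark.ml.PipelineModel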
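`setThreshold` is defined on `LogisticRegression` and on the fitted `LogisticRegressionModel`, but not on `PipelineModel`, which is why the call above does not compile. One way around this is to take the fitted classifier out of `model.stages` and set the threshold on it. The sketch below assumes, as in the pipeline in the question, that the classifier is the last stage; the names `PipelineThreshold`, `logisticStage`, `predictionsAt` and the five-row dataset are made up for illustration.

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

object PipelineThreshold {
  // Assumes the classifier is the last stage, as in the question's pipeline.
  def logisticStage(model: PipelineModel): LogisticRegressionModel =
    model.stages.last.asInstanceOf[LogisticRegressionModel]

  // Fits a toy pipeline, sets threshold t on the fitted LR stage and
  // returns the resulting values of the "prediction" column.
  def predictionsAt(t: Double): Seq[Double] = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("threshold-sketch")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical data: one numeric feature "x" and a 0/1 label.
    val df = Seq((0.1, 0.0), (0.3, 1.0), (0.5, 0.0), (0.7, 1.0), (0.9, 1.0))
      .toDF("x", "label")
    val stages: Array[PipelineStage] = Array(
      new VectorAssembler().setInputCols(Array("x")).setOutputCol("features"),
      new LogisticRegression())
    val model = new Pipeline().setStages(stages).fit(df)

    // PipelineModel has no setThreshold, but its fitted LR stage does.
    logisticStage(model).setThreshold(t)
    model.transform(df).select("prediction").as[Double].collect().toSeq
  }

  def main(args: Array[String]): Unit = {
    println(predictionsAt(0.5))
    SparkSession.getActiveSession.foreach(_.stop())
  }
}
```

`PipelineModel.transform` runs the same stage objects that `stages` returns, so changing the threshold on the extracted stage changes the `prediction` column of later `model.transform` calls. The `probability` column does not depend on the threshold.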

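My code:

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.feature.{PCA, VectorAssembler}
    import org.apache.spark.ml.linalg.{DenseVector, Vector}
    import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
    import org.apache.spark.sql.Row
    import spark.implicits._ // $"..." syntax and Dataset.map encoders (auto-imported in spark-shell)

    val Array(train, test) = dataset
      .randomSplit(Array(0.8, 0.2), seed = 1234L)
      .map(_.cache())

    // Note: "label" is one of the input columns, so the target leaks into the features.
    val assembler = new VectorAssembler()
      .setInputCols(Array("label", "id", "features"))
      .setOutputCol("feature")
    val pca = new PCA()
      .setInputCol("feature")
      .setK(2)
      .setOutputCol("pcaFeatures")
    val classifier = new LogisticRegression()
      .setFeaturesCol("pcaFeatures")
      .setLabelCol("label")
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
    val pipeline = new Pipeline().setStages(Array(assembler, pca, classifier))
    val model = pipeline.fit(train)
    val predicted = model.transform(test)
    predicted.show()

    // Probability of class 0 for the rows whose true label is 1
    val predictions = predicted
      .filter(row => row.getAs[Int]("label") == 1)
      .map(row => (row.getAs[Int]("label"), row.getAs[DenseVector]("probability")(0)))
    predictions.show()

    // BinaryClassificationMetrics expects (score, label) pairs, score first.
    // The probability of class 1 is used as the score; the 0/1 "prediction"
    // column would give only a couple of distinct thresholds.
    val scoreAndLabels = predicted
      .select($"probability", $"label".cast("double"))
      .rdd
      .map(row => (row.getAs[Vector](0)(1), row.getDouble(1)))

    val metrics = new BinaryClassificationMetrics(scoreAndLabels)

    val precision = metrics.precisionByThreshold()
    precision.foreach { case (t, p) =>
      println(s"Threshold is: $t, Precision is: $p")
    }
    val recall = metrics.recallByThreshold()
    recall.foreach { case (t, r) =>
      println(s"Threshold is: $t, Recall is: $r")
    }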


+---+-------------------+-----+--------------------+--------------------+--------------------+--------------------+----------+
| id|           features|label|             feature|         pcaFeatures|       rawPrediction|         probability|prediction|
+---+-------------------+-----+--------------------+--------------------+--------------------+--------------------+----------+
|  3|0.03731932516607228|    1|[1.0,3.0,0.037319...|[-3.0000000581646...|[-0.8840273374633...|[0.29234391132806...|       1.0|
|  7| 0.9636476860201426|    1|[1.0,7.0,0.963647...|[-7.0000000960209...|[-0.8831455606697...|[0.29252636578097...|       1.0|
|  8| 0.4766320058073684|    0|[0.0,8.0,0.476632...|[-8.0000000194785...|[0.87801311177017...|[0.70641031990863...|       0.0|
| 45| 0.1474318959104205|    1|[1.0,45.0,0.14743...|[-45.000000062664...|[-0.8839183791391...|[0.29236645302163...|       1.0|
|103| 0.3443839885873453|    1|[1.0,103.0,0.3443...|[-103.00000007071...|[-0.8837251994055...|[0.29240642125330...|       1.0|
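`precisionByThreshold` and `recallByThreshold` return whole curves over many thresholds. At one fixed threshold t, precision, recall and F1 reduce to counts of true positives (TP), false positives (FP) and false negatives (FN). The plain-Scala sketch below (no Spark) shows that arithmetic. It counts a score strictly greater than t as a positive prediction, which matches how a binary `LogisticRegressionModel` compares `probability(1)` with its threshold; the names `ThresholdMetrics`, `at` and `Scores` are hypothetical.

```scala
object ThresholdMetrics {
  final case class Scores(precision: Double, recall: Double, f1: Double)

  // scoreAndLabels: (probability of class 1, label) pairs, label in {0.0, 1.0}.
  // precision = TP / (TP + FP), recall = TP / (TP + FN),
  // F1 = 2 * precision * recall / (precision + recall); 0.0 when undefined.
  def at(scoreAndLabels: Seq[(Double, Double)], t: Double): Scores = {
    val predictedPos = scoreAndLabels.filter { case (s, _) => s > t }
    val tp = predictedPos.count { case (_, l) => l == 1.0 }.toDouble
    val fp = predictedPos.size - tp
    val fn = scoreAndLabels.count { case (s, l) => s <= t && l == 1.0 }.toDouble
    val precision = if (tp + fp == 0) 0.0 else tp / (tp + fp)
    val recall = if (tp + fn == 0) 0.0 else tp / (tp + fn)
    val f1 =
      if (precision + recall == 0) 0.0
      else 2 * precision * recall / (precision + recall)
    Scores(precision, recall, f1)
  }
}
```

For a small test set, the (probability of class 1, label) pairs could be collected to the driver and passed in, e.g. `ThresholdMetrics.at(pairs, 0.5)` with a hypothetical `pairs: Seq[(Double, Double)]`. For large data, Spark's own metrics classes avoid collecting everything.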

How can I set the threshold t of the logistic regression model when it is part of a pipeline?
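The threshold can also be set on the estimator before fitting, e.g. `classifier.setThreshold(0.5)` (0.5 is already the default for `LogisticRegression`). Either way, the threshold only changes the `prediction` column, so precision, recall and F1 at that threshold can be read from `prediction` versus `label`. A minimal sketch with `MulticlassMetrics`, assuming a numeric `label` column and class 1.0 as the positive class; `MetricsAtThreshold`, `positiveClassMetrics` and the four sample rows are made up for illustration.

```scala
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

object MetricsAtThreshold {
  // Precision, recall and F1 for class 1.0, computed from the "prediction"
  // column, which already reflects the classifier's threshold.
  def positiveClassMetrics(predicted: DataFrame): (Double, Double, Double) = {
    val predictionAndLabels = predicted
      .select(col("prediction"), col("label").cast("double"))
      .rdd
      .map(row => (row.getDouble(0), row.getDouble(1)))
    val metrics = new MulticlassMetrics(predictionAndLabels)
    (metrics.precision(1.0), metrics.recall(1.0), metrics.fMeasure(1.0))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("metrics-sketch")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical (prediction, label) rows standing in for model.transform(test).
    val predicted = Seq((1.0, 1), (1.0, 0), (0.0, 1), (1.0, 1)).toDF("prediction", "label")
    val (p, r, f1) = positiveClassMetrics(predicted)
    println(s"precision=$p recall=$r f1=$f1")
    spark.stop()
  }
}
```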
