Spark 2 logisticregression删除阈值

时间:2017-05-23 00:46:54

标签: scala apache-spark machine-learning distributed-computing

我使用Spark 2 + Scala训练基于LogisticRegression的二进制分类模型,我使用import org.apache.spark.ml.classification.LogisticRegression,这是Spark 2中新的ml API。但是,当我评估模型时通过AUROC,我没有找到使用概率的方法(在0-1中加倍)而不是二进制分类(0/1)。这是以前由removeThreshold()实现的,但在ml.LogisticRegression我没有找到类似的方法。那么,有没有办法做到这一点?

我使用的评估员是

val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("label")
  .setRawPredictionCol("rawPrediction")
  .setMetricName("areaUnderROC")
val auroc = evaluator.evaluate(predictions)`

2 个答案:

答案 0 :(得分:0)

如果你想获得0/1输出以外的概率输出,试试这个:

import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
val lr = new LogisticRegression()
  .setMaxIter(100)
  .setRegParam(0.3)
val lrModel = lr.fit(trainData)
val summary = lrModel.summary
summary.predictions.select("probability").show()

答案 1 :(得分:0)

import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary,
LogisticRegression}
val lr = new LogisticRegression().setMaxIter(100).setRegParam(0.3)
val lrModel = lr.fit(trainData)  
val trainingSummary = lrModel.summary
val predictions = lrModel.transform(test)
predictions.select("label", "probability").show()