How do I find the probability of the predicted class in Spark MLlib classifiers?

Time: 2015-07-06 18:23:19

Tags: apache-spark classification probability apache-spark-mllib

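Spark MLlib provides several classification algorithms, such as Random Forests and Logistic Regression. The examples for training a classifier and predicting a class are straightforward. However, it is not clear which classifier API returns the probability that a given instance belongs to the predicted class. Take the random forest classifier, for example: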

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.log4j.Logger
import org.apache.log4j.Level
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils

object RFClassifier {

  def main(args: Array[String]): Unit = {

    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    // set up environment
    val conf = new SparkConf()
      .setMaster("local[5]")
      .setAppName("RFClassifier")
      .set("spark.executor.memory", "2g")
    val sc = new SparkContext(conf)

    // Load and parse the data file.
    val data = MLUtils.loadLibSVMFile(sc, "in/sample_libsvm_data.txt")
    // Split the data into training and test sets (30% held out for testing)
    val splits = data.randomSplit(Array(0.7, 0.3))
    val (trainingData, testData) = (splits(0), splits(1))

    // Train a RandomForest model.
    //  Empty categoricalFeaturesInfo indicates all features are continuous.
    val numClasses = 2
    val categoricalFeaturesInfo = Map[Int, Int]()
    val numTrees = 3 // Use more in practice.
    val featureSubsetStrategy = "auto" // Let the algorithm choose.
    val impurity = "gini"
    val maxDepth = 4
    val maxBins = 32

    val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
      numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

    // Evaluate model on test instances and compute test error
    val labelAndPreds = testData.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }
    val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
    println("Test Error = " + testErr)
    println("Learned classification forest model:\n" + model.toDebugString)

    // Save and load model
    model.save(sc, "RFClassifierModel")
    val sameModel = RandomForestModel.load(sc, "RFClassifierModel")
  }
}

How can I find the probability of the predicted class? The same problem exists for the other classifiers. Any ideas? Thanks!
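For the random forest case, one possible direction is to look at how the individual trees vote. The sketch below is not an official MLlib probability API; it assumes that the trained `RandomForestModel` exposes its trees through `model.trees` and that each tree has its own `predict` method, and it treats the share of trees voting for a class as a rough score. The object and function names (`VoteFractions`, `voteFractions`, `predictWithScore`) are made up for illustration. With only 3 trees the score can only take a few distinct values, so it is coarse.

```scala
object VoteFractions {

  // Given one predicted label per tree, return the share of trees
  // that voted for each label (shares sum to 1.0).
  def voteFractions(treePredictions: Seq[Double]): Map[Double, Double] = {
    require(treePredictions.nonEmpty, "need at least one tree prediction")
    val total = treePredictions.size.toDouble
    treePredictions
      .groupBy(identity)
      .map { case (label, votes) => label -> votes.size / total }
  }

  // Label with the largest vote share, together with that share.
  // Tie-breaking here may differ from what model.predict does.
  def predictWithScore(treePredictions: Seq[Double]): (Double, Double) =
    voteFractions(treePredictions).maxBy(_._2)
}

// Hypothetical use with the model from the question
// (assumes model.trees is accessible in your Spark version):
//
//   val perTree = model.trees.map(_.predict(point.features)).toSeq
//   val (label, score) = VoteFractions.predictWithScore(perTree)
```

A vote share is not a calibrated probability; it only says how strongly the trees agree on the predicted class.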

Update

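As a rough workaround: every type of classifier I might use first has to be trained on the training set. Once training is done, I can always compute the percentage of correct predictions on that training set. Can this percentage serve as a crude estimate of the probability that any instance belongs to its predicted class? For example, if a given classifier predicts 80% of the training set correctly, can we assume that, for that classifier, the average probability that an instance belongs to its predicted class is 0.8?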
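To make the idea concrete, here is a minimal sketch of a slightly finer version of that estimate: instead of one overall percentage, it computes, for each predicted class, the fraction of instances given that prediction whose true label matches. It assumes the (label, prediction) pairs have been brought to the driver, for example with `labelAndPreds.collect()` on a small dataset, and that they come from held-out data, since accuracy measured on the training set tends to be optimistic. The names `PredictedClassPrecision` and `precisionByPredictedClass` are hypothetical.

```scala
object PredictedClassPrecision {

  // pairs: (trueLabel, predictedLabel)
  // Returns, for every predicted label, the fraction of instances
  // with that prediction whose true label is the same.
  def precisionByPredictedClass(pairs: Seq[(Double, Double)]): Map[Double, Double] =
    pairs
      .groupBy { case (_, predicted) => predicted }
      .map { case (predicted, group) =>
        val correct = group.count { case (label, _) => label == predicted }
        predicted -> correct.toDouble / group.size
      }
}

// Hypothetical use with the pairs from the question (small data only):
//
//   val perClass = PredictedClassPrecision.precisionByPredictedClass(labelAndPreds.collect().toSeq)
```

This still assigns the same number to every instance that receives a given prediction, so it is an average over a class rather than a per-instance probability.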

0 answers:

No answers yet