如何使用CrossValidator进行精确/召回以使用Spark训练NaiveBayes模型

时间:2016-06-12 19:59:31

标签: apache-spark apache-spark-mllib apache-spark-ml apache-spark-1.5

Supossed我有这样的管道:

val tokenizer = new Tokenizer().setInputCol("tweet").setOutputCol("words")
val hashingTF = new HashingTF().setNumFeatures(1000).setInputCol("words").setOutputCol("features")
val idf = new IDF().setInputCol("features").setOutputCol("idffeatures")
val nb = new org.apache.spark.ml.classification.NaiveBayes()
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, idf, nb))
val paramGrid = new ParamGridBuilder().addGrid(hashingTF.numFeatures, Array(10, 100, 1000)).addGrid(nb.smoothing, Array(0.01, 0.1, 1)).build()
val cv = new CrossValidator().setEstimator(pipeline).setEvaluator(new BinaryClassificationEvaluator()).setEstimatorParamMaps(paramGrid).setNumFolds(10)
val cvModel = cv.fit(df)

如您所见,我使用MultiClassClassificationEvaluator定义了一个CrossValidator。我已经看到很多示例在测试过程中获得像Precision / Recall这样的指标,但是当您使用不同的数据集进行测试时会得到这些指标(例如参见此documentation)。

根据我的理解,CrossValidator将创建折叠,一次折叠将用于测试目的,然后CrossValidator将选择最佳模型。我的问题是,在培训过程中可以获得精确/召回指标吗?

1 个答案:

答案 0 :(得分:2)

嗯,实际存储的唯一指标是您在创建Evaluator实例时定义的指标。对于BinaryClassificationEvaluator,这可以采用以下两个值中的一个:

  • areaUnderROC
  • areaUnderPR

前一个是默认值,可以使用setMetricName方法设置。

这些值是在培训过程中收集的,可以使用CrossValidatorModel.avgMetrics进行访问。值的顺序对应于EstimatorParamMapsCrossValidatorModel.getEstimatorParamMaps)的顺序。