Spark K-fold交叉验证

时间:2016-06-20 09:43:40

标签: machine-learning classification apache-spark-mllib cross-validation

我在理解Spark的交叉验证方面遇到了一些麻烦。我见过的任何一个例子都用它来进行参数调整,但我认为它只会进行常规的K折交叉验证吗?

我想要做的是进行k折交叉验证,其中k = 5。我想获得每个结果的准确性,然后获得平均准确度。 在scikit中学习这是怎么做的,分数会给你每个折叠的结果,然后你可以使用scores.mean()

scores = cross_val_score(classifier, y, x, cv=5, scoring='accuracy')

这就是我在Spark中的做法,paramGridBuilder是空的,因为我不想输入任何参数。

val paramGrid = new ParamGridBuilder().build()
val evaluator = new MulticlassClassificationEvaluator()
  evaluator.setLabelCol("label")
  evaluator.setPredictionCol("prediction")
evaluator.setMetricName("precision")


val crossval = new CrossValidator()
crossval.setEstimator(classifier)
crossval.setEvaluator(evaluator) 
crossval.setEstimatorParamMaps(paramGrid)
crossval.setNumFolds(5)


val modelCV = crossval.fit(df4)
val chk = modelCV.avgMetrics

这是否与scikit学习实施相同?为什么这些示例在进行交叉验证时会使用培训/测试数据?

How to cross validate RandomForest model?

https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala

1 个答案:

答案 0 :(得分:3)

  1. 你正在做的事情看起来不错。
  2. 基本上,是的,它的作用与sklearn的grid search CV相同 对于每个EstimatorParamMaps(一组参数),使用CV测试算法,因此DECLARE @T TABLE ( [ID] INT, [X] INT, [Y] INT ) INSERT INTO @T SELECT [ID], [X], [Y] FROM [Redaction] AS R WHERE [ID] IN ( SELECT [PrimaryID] FROM [LinkedRedactions] ) UPDATE [Redaction] SET [X] = @T.[X], [Y] = @T.[Y] WHERE [Redaction].[ID] IN ( SELECT [ID] FROM @T ) 是所有折叠上的平均交叉验证精度度量/ s。 如果一个人使用空avgMetrics(没有参数搜索),它就像是"常规"交叉验证"我们将产生一个交叉验证的培训准确性。
  3. 每个CV迭代包括ParamGridBuilder训练折叠和K-1测试折叠,那么为什么大多数示例在进行交叉验证之前将数据分离为训练/测试数据? 因为CV内的测试折叠用于参数网格搜索。 这意味着模型选择需要额外的验证数据集。 所谓的"测试数据集"需要评估最终模型。阅读更多here