作为paramGrid的一部分,我在logistic回归中使用regParam运行spark ml交叉验证。
val paramGrid = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.1, 0.01))
.build()
val validator = new CrossValidator()
.setEstimator(estimator)
.setEvaluator(evaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(3)
估算器这里有regParam作为params的一部分。 保存模型的示例代码:
class MyModelWriter(instance: MyModel[T])extends MLWriter {
override protected def saveImpl(path: String): Unit = {
new DefaultParamsWriter(instance).save(path)
instance.model.save(new Path(path, s"nameOfMofel").toString)
}
}
Mymodel确实在params中包含了regParam。
MyModel extends HasRegParam
当我调用model.save(path)时,这是我得到的异常:
java.lang.IllegalArgumentException:要求失败:ValidatorParams save要求estimatorParamMaps中的所有Params应用于此ValidatorParams,其Estimator或其Evaluator。发现了一个无关的Param:logreg_2fb5fdbe5012__regParam 在scala.Predef $ .require(Predef.scala:224) 在org.apache.spark.ml.tuning.ValidatorParams $$ anonfun $ validateParams $ 1 $$ anonfun $ apply $ 1.apply(ValidatorParams.scala:110) 在org.apache.spark.ml.tuning.ValidatorParams $$ anonfun $ validateParams $ 1 $$ anonfun $ apply $ 1.apply(ValidatorParams.scala:109) 在scala.collection.mutable.ResizableArray $ class.foreach(ResizableArray.scala:59) 在scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48) 在org.apache.spark.ml.tuning.ValidatorParams $$ anonfun $ validateParams $ 1.apply(ValidatorParams.scala:109) 在org.apache.spark.ml.tuning.ValidatorParams $$ anonfun $ validateParams $ 1.apply(ValidatorParams.scala:108) 在scala.collection.IndexedSeqOptimized $ class.foreach(IndexedSeqOptimized.scala:33) at scala.collection.mutable.ArrayOps $ ofRef.foreach(ArrayOps.scala:186) at org.apache.spark.ml.tuning.ValidatorParams $ .validateParams(ValidatorParams.scala:108) 在org.apache.spark.ml.tuning.CrossValidatorModel $ CrossValidatorModelWriter。(CrossValidator.scala:257) 在org.apache.spark.ml.tuning.CrossValidatorModel.write(CrossValidator.scala:242) at org.apache.spark.ml.util.MLWritable $ class.save(ReadWrite.scala:157) 在org.apache.spark.ml.tuning.CrossValidatorModel.save(CrossValidator.scala:210) 在com.criteo.lookalike.sink.Sinks $$ anonfun $ SavePipelineParam1 $ 1.apply(Sinks.scala:111
L105上ValidatorParams.scala的代码说
//检查以确保所有Params都适用于此估算工具。如果没有,则抛出错误。
根据这一点,确保estimatorMap中的参数,即在这种情况下的regParam存在于估计器/评估器中,在这种情况下,确实存在于上面的Mymodel中。
任何人都可以告诉我,如果我的理解是正确的,如果是的话,可能是什么导致了这一点?感谢。
答案 0 :(得分:0)
我只是解决了这个确切的错误。
添加网格时,请尝试传递一个Param实例;并在实例化参数时,将其与记录的参数类型相匹配,就像您在https://spark.apache.org/docs/latest/api/scala/下找到的一样。
例如,在RandomForestRegressor
中有numTrees: IntParam
。
因此,我按如下所示构建参数网格...
val rf = new RandomForestRegressor()
.{set...()} // (pseudocode)
val numTrees = new IntParam(rf, "numTrees", "Number of trees to train (>= 1) (default = 20)")
// for fun/preference, i make numTrees[Int] increase as does the area of a circle
val numTreesValues = (for (n <- 3 to 20 by 3) yield (math.Pi * math.pow(n, 2)).toInt)
val paramGrid = new ParamGridBuilder()
.addGrid(numTrees, numTreesValues)
.build()
尝试将估计量传递给参数,然后将参数和值传递给.addGrid
然后我的验证器看起来像这样...
val cv = new CrossValidator()
.setEstimator(rf)
.setEstimatorParamMaps(paramGrid)
.{set...()}