在Spark中加载经过培训的crossValidation模型

时间:2016-04-06 09:15:39

标签: apache-spark logistic-regression cross-validation bigdata

我是Apache Spark的新手。我使用crossValidation训练了一个LogisticRegression模型。例如:

  

val cv = new CrossValidator()     .setEstimator(管道)     .setEvaluator(new BinaryClassificationEvaluator)     .setEstimatorParamMaps(paramGrid)     .setNumFolds(5)   val cvModel = cv.fit(data)

我能够毫无错误地训练和测试我的模型。然后我使用以下方法保存模型和管道:

  

cvModel.save(" /路径到我的模型/火花日志REG-转印模型&#34)   pipeline.save(" /路径到我的流水线/火花日志REG-传送管道&#34)

直到这个阶段,这些行动都很完美。然后,我试图加载我的模型以预测新数据点,然后发生以下错误:

  

val sameModel = PipelineModel.load(" / path-to-my-model / spark-log-reg-transfer-model")

java.lang.IllegalArgumentException:要求失败:加载元数据时出错:预期的类名org.apache.spark.ml.PipelineModel但找到了类名org.apache.spark.ml.tuning.CrossValidatorModel

知道我做错了什么吗?感谢。

2 个答案:

答案 0 :(得分:3)

您正在尝试使用PipelineModel对象加载CrossValidator。 你应该使用正确的装载机......

val crossValidator = CrossValidator.load("/path-to-my-model/spark-log-reg-transfer-model")

val sameModel = PipelineModel.load("/path-to-my-pipeline/spark-log-reg-transfer-pipeline")

答案 1 :(得分:1)

要加载交叉验证器,它应该是

val crossValidator = CrossValidator.load("/path-to-my-model/spark-log-reg-transfer-model")

加载交叉验证器模型使用 (注意:当您在CrossValidator上调用fit()时,Cross Validator将成为Cross Validator模型)

val crossValidatorModel = CrossValidatorModel.load("/path-to-my-model/spark-log-reg-transfer-model")

由于您正在尝试加载模型,因此CrossValidatorModel.load将是正确的。