Spark ML - 保存OneVsRestModel

时间:2016-03-27 03:17:30

标签: scala apache-spark apache-spark-mllib apache-spark-ml

我正在重构我的代码以利用DataFrames, Estimators, and Pipelines。我最初在RDD[LabeledPoint]上使用MLlib Multiclass LogisticRegressionWithLBFGS。我很享受学习和使用新API,但我不知道如何保存我的新模型并将其应用于新数据。

目前,LogisticRegression的ML实现仅支持二进制分类。我是,而是使用OneVsRest,如此:

val lr = new LogisticRegression().setFitIntercept(true)
val ovr = new OneVsRest()
ovr.setClassifier(lr)
val ovrModel = ovr.fit(training)

我现在要保存OneVsRestModel,但API似乎不支持此功能。我试过了:

ovrModel.save("my-ovr") // Cannot resolve symbol save
ovrModel.models.foreach(_.save("model-" + _.uid)) // Cannot resolve symbol save

有没有办法保存,所以我可以将它加载到新的应用程序中进行新的预测?

1 个答案:

答案 0 :(得分:5)

Spark 2.0.0

OneVsRestModel实现MLWritable因此应该可以直接保存它。下面显示的方法对于单独保存单个模型仍然很有用。

Spark< 2.0.0

此处的问题是,models会返回Array ClassificationModel[_, _]]而不是Array LogisticRegressionModel(或MLWritable)。为了使其有效,您必须具体说明类型:

import org.apache.spark.ml.classification.LogisticRegressionModel

ovrModel.models.zipWithIndex.foreach { 
  case (model: LogisticRegressionModel, i: Int) => 
    model.save(s"model-${model.uid}-$i")
}

或更通用:

import org.apache.spark.ml.util.MLWritable

ovrModel.models.zipWithIndex.foreach { 
  case (model: MLWritable, i: Int) =>
    model.save(s"model-${model.uid}-$i")
}

不幸的是,就目前而言(Spark 1.6)OneVsRestModel没有实现MLWritable因此无法单独保存。

注意

OneVsRest中的所有模型似乎都使用相同的uid,因此我们需要一个显式索引。稍后识别模型也很有用。