Spark ML Pipeline api保存不起作用

时间:2016-01-11 23:13:22

标签: java apache-spark apache-spark-ml

在版本1.6中,管道api获得了一组新的功能来保存和加载管道阶段。在我训练了分类器并稍后再次加载它以重新使用它并节省计算再次建模的努力之后,我试图将一个阶段保存到磁盘。

出于某种原因,当我保存模型时,该目录仅包含元数据目录。当我尝试再次加载它时,我得到以下异常:

  

线程“main”中的异常java.lang.UnsupportedOperationException:   空集合在   org.apache.spark.rdd.RDD $$ anonfun $ first $ 1.apply(RDD.scala:1330)at at   org.apache.spark.rdd.RDDOperationScope $ .withScope(RDDOperationScope.scala:150)     在   org.apache.spark.rdd.RDDOperationScope $ .withScope(RDDOperationScope.scala:111)     在org.apache.spark.rdd.RDD.withScope(RDD.scala:316)at   org.apache.spark.rdd.RDD.first(RDD.scala:1327)at   org.apache.spark.ml.util.DefaultParamsReader $ .loadMetadata(ReadWrite.scala:284)     在   org.apache.spark.ml.tuning.CrossValidator $ SharedReadWrite $ .load(CrossValidator.scala:287)     在   org.apache.spark.ml.tuning.CrossValidatorModel $ CrossValidatorModelReader.load(CrossValidator.scala:393)     在   org.apache.spark.ml.tuning.CrossValidatorModel $ CrossValidatorModelReader.load(CrossValidator.scala:384)     在   org.apache.spark.ml.util.MLReadable $ class.load(ReadWrite.scala:176)     在   org.apache.spark.ml.tuning.CrossValidatorModel $ .load(CrossValidator.scala:368)     在   org.apache.spark.ml.tuning.CrossValidatorModel.load(CrossValidator.scala)     在   org.test.categoryminer.spark.SparkTextClassifierModelCache.get(SparkTextClassifierModelCache.java:34)

保存我使用的模型:crossValidatorModel.save("/tmp/my.model")

并加载它我使用:CrossValidatorModel.load("/tmp/my.model")

我调用了在CrossValidator对象上调用fit(dataframe)时得到的CrossValidatorModel对象的保存。

任何指针为什么它只保存元数据目录?

1 个答案:

答案 0 :(得分:2)

这肯定不会直接回答你的问题,但我个人并未测试1.6.0中的新功能。

我正在使用专用功能来保存模型。

  def saveCrossValidatorModel(model:CrossValidatorModel, path:String)
  {
    try {
          val fileOut:FileOutputStream  = new FileOutputStream(path)
          val out:ObjectOutputStream  = new ObjectOutputStream(fileOut)
          out.writeObject(model)
          out.close()
          fileOut.close()
      } catch {
        case foe:FileNotFoundException =>
          foe.printStackTrace()
        case ioe:IOException =>
          ioe.printStackTrace()
      }
  }

然后你可以用类似的方式阅读你的模型:

  def loadCrossValidatorModel(path:String): CrossValidatorModel =
  {
    try {
      val fileIn:FileInputStream = new FileInputStream(path)
      val in:ObjectInputStream  = new ObjectInputStream(fileIn)
      val cvModel = in.readObject().asInstanceOf[CrossValidatorModel]
      in.close()
      fileIn.close()
      cvModel
    } catch {
        case foe:FileNotFoundException =>
          foe.printStackTrace()
        case ioe:IOException =>
          ioe.printStackTrace()
      }
  }