Unable to read a pipeline model in Spark that uses a custom UnaryTransformer

Date: 2018-07-23 09:53:38

Tags: apache-spark apache-spark-mllib

I have defined a new custom UnaryTransformer in Spark (cleanText in the sample code below) and use it in a pipeline. When I save the fitted pipeline and then try to read it back, I get the following error:

java.lang.NoSuchMethodException: test_job$cleanText.read()

Saving and loading just the unary transformer on its own works fine.
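
The `test_job$cleanText` in the error suggests the transformer class is compiled as a member of the test_job object. Loading a saved PipelineModel goes through Spark's DefaultParamsReader, which reflectively looks up a static read() method on each stage's class. A minimal sketch of that lookup, with the class name string taken from the error message (an illustration only, not Spark's actual source):

// Illustrative sketch of the reflective stage lookup performed during
// PipelineModel.load, using the class name recorded in the stage metadata.
val className = "test_job$cleanText"  // from the saved metadata
val cls = Class.forName(className)
// Scala emits a static read() forwarder only for companions of top-level
// classes; for a class nested inside an object there is none, so this
// call throws java.lang.NoSuchMethodException.
val reader = cls.getMethod("read").invoke(null)

Loading the transformer directly via cleanText.load(...) calls the companion object without any reflection, which would explain why saving and loading the transformer on its own succeeds.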

Sample code to reproduce the error (tested on Spark 2.2):

import org.apache.spark.ml.{Pipeline, PipelineModel, UnaryTransformer}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DataType, StringType}


object test_job {

  class cleanText(override val uid: String)
    extends UnaryTransformer[String, String, cleanText] with DefaultParamsWritable {

    def this() = this(Identifiable.randomUID("cleantext"))

    override protected def validateInputType(inputType: DataType): Unit = {
      require(inputType == StringType)
    }

    // Replace every character that is not a letter or digit with a space.
    protected def createTransformFunc: String => String = {
      val regex = "[^a-zA-Z0-9]".r
      s => regex.replaceAllIn(s, " ")
    }

    protected def outputDataType: DataType = StringType
  }

  object cleanText extends DefaultParamsReadable[cleanText]
  // {
  //   override def load(path: String): cleanText = super.load(path)
  // }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("test_job").getOrCreate()
    import spark.implicits._

    val cleaner = new cleanText()
    cleaner.setInputCol("word").setOutputCol("r_clean")

    val someDF = spark.sparkContext.parallelize(Seq(
      (1, "sample text 1"),
      (2, "sample text 2"),
      (3, "sample text 3")
    )).toDF("number", "word")

    val pipeline = new Pipeline().setStages(Array(cleaner))

    val pipeline_fitted = pipeline.fit(someDF)
    pipeline_fitted.write.overwrite().save("/tmp/model/")
    // Saving just the transformer:
    // cleaner.write.overwrite().save("/tmp/model/")
    println("Pipeline saved")

    val pl2 = PipelineModel.load("/tmp/model/")
    // Loading just the transformer works:
    // val cln = cleanText.load("/tmp/model/")

    println("Pipeline loaded")
    spark.stop()
  }

}
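
For reference, the remedy usually suggested for this kind of NoSuchMethodException (an assumption here, not verified against this exact job) is to define the transformer and its companion at the top level of the file, outside the test_job object, so the compiler emits the static read() forwarder that the reflective loader expects:

// Hypothetical restructuring: cleanText moved to the top level of the file,
// outside object test_job; the class body is identical to the one above.
class cleanText(override val uid: String)
  extends UnaryTransformer[String, String, cleanText] with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("cleantext"))

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType)
  }

  protected def createTransformFunc: String => String = {
    val regex = "[^a-zA-Z0-9]".r
    s => regex.replaceAllIn(s, " ")
  }

  protected def outputDataType: DataType = StringType
}

// A top-level companion gets a static read() forwarder on the class,
// which is what DefaultParamsReader looks up via reflection.
object cleanText extends DefaultParamsReadable[cleanText]

Only the position of the class in the file changes; the pipeline code in main stays the same.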

0 Answers:

No answers