I defined a new custom UnaryTransformer in Spark (cleanText in the sample code below) and used it in a pipeline. When I save the fitted pipeline and then try to read it back, I get the following error:
java.lang.NoSuchMethodException: test_job$cleanText.read()
Saving and loading just the UnaryTransformer on its own works fine.
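For reference, this is the standalone round trip that succeeds (a minimal sketch; /tmp/transformer/ is an illustrative path):

val cleaner = new cleanText().setInputCol("word").setOutputCol("r_clean")
cleaner.write.overwrite().save("/tmp/transformer/")
val restored = cleanText.load("/tmp/transformer/") // resolved via DefaultParamsReadable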
Sample code to reproduce the error (tested on Spark 2.2):
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.{Pipeline, PipelineModel, UnaryTransformer}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DataType, StringType}
object test_job {

  // Custom UnaryTransformer: replaces every non-alphanumeric character with a space.
  class cleanText(override val uid: String)
    extends UnaryTransformer[String, String, cleanText] with DefaultParamsWritable {

    def this() = this(Identifiable.randomUID("cleantext"))

    override protected def validateInputType(inputType: DataType): Unit = {
      require(inputType == StringType)
    }

    protected def createTransformFunc: String => String = {
      val regex = "[^a-zA-Z0-9]".r
      s => regex.replaceAllIn(s, m => " ")
    }

    protected def outputDataType: DataType = StringType
  }

  // Companion object so the transformer can be read back.
  object cleanText extends DefaultParamsReadable[cleanText]
  //{
  //  override def load(path: String): cleanText = super.load(path)
  //}
  def main(args: Array[String]): Unit = {
    val sc: SparkContext = new SparkContext(new SparkConf().setAppName("test_job"))
    val sqlc = SparkSession.builder.appName("test_job").getOrCreate()
    import sqlc.implicits._

    val cleaner = new cleanText()
    cleaner.setInputCol("word").setOutputCol("r_clean")

    val someDF = sc.parallelize(Seq(
      (1, "sample text 1"),
      (2, "sample text 2"),
      (3, "sample text 3")
    )).toDF("number", "word")

    val pipeline = new Pipeline().setStages(Array(cleaner))
    val pipeline_fitted = pipeline.fit(someDF)
    pipeline_fitted.write.overwrite().save("/tmp/model/")
    // Saving just the transformer:
    //cleaner.write.overwrite().save("/tmp/model/")
    println("Pipeline saved")

    // Fails with java.lang.NoSuchMethodException: test_job$cleanText.read()
    val pl2 = PipelineModel.load("/tmp/model/")
    // Loading just the transformer will work:
    //val cln = cleanText.load("/tmp/model/")
    println("Pipeline loaded")

    sqlc.stop()
  }
}
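My suspicion is that the nesting is the culprit. In Spark 2.x, PipelineModel.load resolves each stage through DefaultParamsReader, which reflectively calls a static read() method on the stage's class; Scala only emits that static forwarder for the companion of a top-level class, so the nested test_job$cleanText exposes no read() for reflection to find. A sketch of the restructuring that should sidestep this, assuming that diagnosis is right (same class body, just moved out of test_job):

// Moved to the top level: the companion's read() now gets a static forwarder,
// which DefaultParamsReader can locate via reflection.
class cleanText(override val uid: String)
  extends UnaryTransformer[String, String, cleanText] with DefaultParamsWritable {
  def this() = this(Identifiable.randomUID("cleantext"))
  // ... same validateInputType / createTransformFunc / outputDataType as above ...
}

object cleanText extends DefaultParamsReadable[cleanText]

object test_job {
  def main(args: Array[String]): Unit = {
    // ... pipeline code as above, unchanged ...
  }
}

Can anyone confirm whether this is expected behavior for transformers defined inside an enclosing object, or whether I'm missing something?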