如何使MLReader上的泛型函数

时间:2019-03-08 14:05:09

标签: scala apache-spark apache-spark-ml

我正在使用Spark 1.6.3。这是两个功能相同的函数:

def modelFromBytesCV(modelArray: Array[Byte]): CountVectorizerModel = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  Files.write(tempPath, modelArray)
  CountVectorizerModel.read.load(tempPath.toString)
}

def modelFromBytesIDF(modelArray: Array[Byte]): IDFModel = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  Files.write(tempPath, modelArray)
  IDFModel.read.load(tempPath.toString)
}

我想使这些功能通用。我挂在嘴上的是CountVectorizerModel对象和IDFModel之间的共同特征是MLReadable [T],它本身必须作为CountVectorizerModel或IDFModel作为类型。这是一个递归的父类循环,我无法解决。

通过比较,通用模型编写器很容易,因为MLWritable是我感兴趣的所有模型扩展的共同特征:

def modelToBytes[M <: MLWritable](model: M): Array[Byte] = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  model.write.overwrite().save(tempPath.toString)
  Files.readAllBytes(tempPath)
}

我如何制作一个通用的读取器,该读取器会将spark-ml模型转换为字节数组?

1 个答案:

答案 0 :(得分:2)

要使其正常工作,您需要访问特定的MlReadable对象。

import org.apache.spark.ml.util.MLReadable

def modelFromBytes[M](obj: MLReadable[M], modelArray: Array[Byte]): M = {
  val tempPath: Path = ???
  ...
  obj.read.load(tempPath.toString)
}

稍后可以用作:

val bytes: Array[Byte] = ???
modelFromBytes(CountVectorizerModel, bytes)

请注意,尽管首次出现,但这里没有递归-MLReadable[M]指的是伴随对象,而不是类。例如,CountVectorizerModel objectMLReadable,而CountVectorizeModel class不是。

在内部,Spark MLReader以不同的方式进行处理-it creates an instance of the class using reflection,然后是sets its Params。但是,此路径在这里对您不是很有用*。

如果需要与当前API兼容,则可以尝试使可读对象隐式:

def modelFromBytes[M](modelArray: Array[Byte])(implicit obj: MLReadable[M]): M = {
  ...
}

然后

implicit val readable: MLReadable[CountVectorizerModel] = CountVectorizerModel

modelFromBytes[CountVectorizerModel](bytes)

*从技术上讲,可以通过反射获得伴侣对象

def modelFromBytesCV[M <: MLWritable](
    modelArray: Array[Byte])(implicit ct: ClassTag[M]): M = {
  val tempPath: Path = ???
  ...
  val cls = Class.forName(ct.runtimeClass.getName + "$");
  cls.getField("MODULE$").get(cls).asInstanceOf[MLReadable[M]]
    .read.load(tempPath.toString)) 
}

但是我认为这不是值得探索的路径。特别是,我们在这里不能真正提供严格的类型界限-使用MLWritable是一种限制人为错误的技巧,但是对于编译器却毫无用处。