Spark自定义估算器,包括持久性

时间:2016-11-26 10:38:12

标签: apache-spark apache-spark-sql pipeline apache-spark-mllib apache-spark-ml

我想为spark开发一个自定义估算器,它也可以处理大型管道API的持久性。但是正如How to Roll a Custom Estimator in PySpark mllib所说,那里还没有很多文件(<)。

我有一些用spark编写的数据清理代码,并希望将其包装在自定义估算器中。包括一些na替换,列删除,过滤和基本特征生成(例如,出生日期到年龄)。

  • transformSchema将使用数据集ScalaReflection.schemaFor[MyClass].dataType.asInstanceOf[StructType]
  • 的案例类
  • 适合只适合于平均年龄为na。替代

我还不清楚:

    自定义管道模型中的
  • transform将用于转换&#34; fit&#34;关于新数据的估算器。它是否正确?如果是,我应该如何转移拟合值,例如从上面进入模型的平均年龄?

  • 如何处理持久性?我在私有spark组件中找到了一些通用的loadImpl方法,但我不确定如何传输我自己的参数,例如MLReader / MLWriter中用于序列化的平均年龄。

如果你可以帮助我使用自定义估算器,那将会很棒 - 特别是对于持久性部分。

2 个答案:

答案 0 :(得分:2)

首先,我相信你会混合两件不同的东西:

  • Estimators - 代表可以fit的阶段。 Estimator fit方法需要Dataset并返回Transformer(模型)。
  • Transformers - 代表可以transform数据的阶段。

当您fit Pipeline fits全部Estimators并返回PipelineModel时。 PipelineModel可以transform数据在模型中的所有transform上依次调用Transformers

  

我应该如何转移拟合值

这个问题没有一个答案。通常,您有两种选择:

  • 将拟合模型的参数作为Transformer
  • 的参数传递
  • 制作Transformer
  • 的拟合模型Params的参数

第一种方法通常由内置Transformer使用,但第二种方法应该在一些简单的情况下使用。

  

如何处理持久性

  • 如果Transformer仅由其Params定义,则可以延长DefaultParamsReadable
  • 如果您使用更复杂的参数,则应扩展MLWritable并实施对您的数据有意义的MLWriter。 Spark源代码中有多个示例,展示了如何实现数据和元数据的读/写。

如果您正在寻找易于理解的示例,请查看CountVectorizer(Model)其中:

答案 1 :(得分:2)

以下使用 Scala API ,但如果您真的想要,可以轻松地将其重构为Python。

首先要做的事情:

  • Estimator :实现.fit()返回Transformer
  • Transformer :实现.transform()并操纵DataFrame
  • 序列化/反序列化:尽量使用内置Params并利用简单的DefaultParamsWritable 特征 + 配套对象扩展DefaultParamsReadable[T]。 a.k.a远离MLReader / MLWriter并保持代码简单。
  • 参数传递:使用扩展Params的公共特征并在Estimator和Model(a.k.a.Transformer)之间共享

骨架代码:

// Common Parameters
trait MyCommonParams extends Params {
  final val inputCols: StringArrayParam = // usage: new MyMeanValueStuff().setInputCols(...)
    new StringArrayParam(this, "inputCols", "doc...")
    def setInputCols(value: Array[String]): this.type = set(inputCols, value)
    def getInputCols: Array[String] = $(inputCols)

  final val meanValues: DoubleArrayParam = 
    new DoubleArrayParam(this, "meanValues", "doc...")
    // more setters and getters
}

// Estimator
class MyMeanValueStuff(override val uid: String) extends Estimator[MyMeanValueStuffModel] 
  with DefaultParamsWritable // Enables Serialization of MyCommonParams
  with MyCommonParams {

  override def copy(extra: ParamMap): Estimator[MeanValueFillerModel] = defaultCopy(extra) // deafult
  override def transformSchema(schema: StructType): StructType = schema // no changes
  override def fit(dataset: Dataset[_]): MyMeanValueStuffModel = {
    // your logic here. I can't do all the work for you! ;)
   this.setMeanValues(meanValues)
   copyValues(new MyMeanValueStuffModel(uid + "_model").setParent(this))
  }
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuff extends DefaultParamsReadable[MyMeanValueStuff]

// Model (Transformer)
class MyMeanValueStuffModel(override val uid: String) extends Model[MyMeanValueStuffModel] 
  with DefaultParamsWritable // Enables Serialization of MyCommonParams
  with MyCommonParams {

  override def copy(extra: ParamMap): MyMeanValueStuffModel = defaultCopy(extra) // default
  override def transformSchema(schema: StructType): StructType = schema // no changes
  override def transform(dataset: Dataset[_]): DataFrame = {
      // your logic here: zip inputCols and meanValues, toMap, replace nulls with NA functions
      // you have access to both inputCols and meanValues here!
  }
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuffModel extends DefaultParamsReadable[MyMeanValueStuffModel]

使用上面的代码,您可以序列化/反序列化包含MyMeanValueStuff阶段的管道。

想看一下Estimator的一些简单实现吗? MinMaxScaler! (我的例子实际上更简单......)