如何在MLlib中编写自定义Transformer?

时间:2016-11-15 17:07:47

标签: scala apache-spark apache-spark-sql apache-spark-mllib

我想为scala中的spark 2.0中的管道编写自定义Transformer。到目前为止,我并不清楚copytransformSchema方法应该返回什么。他们返回null是否正确? https://github.com/SupunS/play-ground/blob/master/test.spark.client_2/src/main/java/CustomTransformer.java复制?

随着Transformer扩展PipelineStage,我得出结论,fit调用了transformSchema方法。我是否正确理解transformSchema类似于sk-learnns fit?

由于我的Transformer应该将数据集与(非常小的)第二个数据集连接起来,我想将其存储在序列化管道中。我应该如何将其存储在变压器中以正确使用管道序列化机制?

一个简单的变换器如何计算单个列的平均值并填充nan值+持续这个值?

@SerialVersionUID(serialVersionUID) // TODO store ibanList in copy + persist
    class Preprocessor2(someValue: Dataset[SomeOtherValues]) extends Transformer {

      def transform(df: Dataset[MyClass]): DataFrame = {

      }

      override def copy(extra: ParamMap): Transformer = {
      }

      override def transformSchema(schema: StructType): StructType = {
        schema
      }
    }

2 个答案:

答案 0 :(得分:3)

transformSchema应该返回应用Transformer后预期的架构。例如:

  • 如果transfomer添加IntegerType列,输出列名称为foo

    import org.apache.spark.sql.types._
    
    override def transformSchema(schema: StructType): StructType = {
       schema.add(StructField("foo", IntegerType))
    }
    
  

因此,如果没有为数据集更改模式,因为只填充了名称值以进行均值插补,我应该将原始案例类作为模式返回吗?

在Spark SQL(以及MLlib)中也是不可能的,因为Dataset一旦创建就不可变。您只能添加或“替换”(添加后跟drop操作)列。

答案 1 :(得分:2)

首先,我不确定你想要Transformer本身(或UnaryTransformer@LostInOverflow suggested in the answer),如你所说:

  

一个简单的变换器如何计算单个列的平均值并填充nan值+持续这个值?

对我来说,好像你想要应用聚合函数(也就是聚合)和"加入"它与所有列产生最终值或NaN。

看起来就像你希望groupBymean进行聚合,然后join也可以进行窗口聚合。

无论如何,我从UnaryTransformer开始,这将解决你问题中的第一个问题:

  

到目前为止,我并不清楚copytransformSchema方法应该返回什么。它们返回null是否正确?

请参阅the complete project spark-mllib-custom-transformer at GitHub,其中我实现了UnaryTransformertoUpperCase字符串列,其中UnaryTransformer的内容如下所示:

import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{DataType, StringType}

class UpperTransformer(override val uid: String)
  extends UnaryTransformer[String, String, UpperTransformer] {

  def this() = this(Identifiable.randomUID("upp"))

  override protected def createTransformFunc: String => String = {
    _.toUpperCase
  }

  override protected def outputDataType: DataType = StringType
}