Spark中的UDAF具有多个输入列

时间:2016-05-12 16:57:34

标签: scala apache-spark user-defined-functions

我正在尝试开发用户定义的聚合函数,该函数计算一行数字的线性回归。我已经成功地完成了一个计算均值置信区间的UDAF(有很多试验和错误以及SO!)。

以下是我实际运行的内容:

import org.apache.spark.sql._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{StructType, StructField, DoubleType, LongType, DataType, ArrayType}

case class RegressionData(intercept: Double, slope: Double)

class Regression  {

  import org.apache.commons.math3.stat.regression.SimpleRegression

  def roundAt(p: Int)(n: Double): Double = { val s = math pow (10, p); (math round n * s) / s }

  def getRegression(data: List[Long]): RegressionData = {
    val regression: SimpleRegression  = new SimpleRegression()
    data.view.zipWithIndex.foreach { d =>
        regression.addData(d._2.toDouble, d._1.toDouble)
    }

    RegressionData(roundAt(3)(regression.getIntercept()), roundAt(3)(regression.getSlope()))
  }
}


class UDAFRegression extends UserDefinedAggregateFunction {

  import java.util.ArrayList

  def deterministic = true

  def inputSchema: StructType =
    new StructType().add("units", LongType)

  def bufferSchema: StructType =
    new StructType().add("buff", ArrayType(LongType))


  def dataType: DataType =
    new StructType()
      .add("intercept", DoubleType)
      .add("slope", DoubleType)

  def initialize(buffer: MutableAggregationBuffer) = {
    buffer.update(0, new ArrayList[Long]())
  }

  def update(buffer: MutableAggregationBuffer, input: Row) = {
    val longList: ArrayList[Long]  = new ArrayList[Long](buffer.getList(0))
    longList.add(input.getLong(0));
    buffer.update(0, longList);

  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    val longList: ArrayList[Long] = new ArrayList[Long](buffer1.getList(0))
    longList.addAll(buffer2.getList(0))

    buffer1.update(0, longList)
  }


  def evaluate(buffer: Row) = {
    import scala.collection.JavaConverters._
    val list = buffer.getList(0).asScala.toList
    val regression = new Regression
    regression.getRegression(list)
  }
}

然而,数据集不是按顺序排列的,这在这里显然非常重要。因此,我需要第二个参数regression($"longValue")而不是regression($"longValue", $"created_day")created_daysql.types.DateType

我对DataTypes,StructTypes和what-not感到困惑,由于网络上缺少示例,我在这里试用了我的试用和订单。

我的bufferSchema会是什么样的?

在我的情况下,这些StructTypes是否有开销?一个(可变的)Map不会这么做吗? MapType实际上是不可变的,并且这不是一个毫无意义的缓冲类型吗?

我的inputSchema会是什么样的?

这是否必须与我在案例update() input.getLong(0)中检索的类型相匹配?

是否有一种标准方法如何在initialize()

中重置缓冲区

我看过buffer.update(0, 0.0)(当它包含双打时,很明显),buffer(0) = new WhatEver()我甚至认为buffer = Nil。这些都有所不同吗?

如何更新数据?

上面的例子似乎过于复杂。我原以为能够做某事。比如buffer += input.getLong(0) -> input.getDate(1)。 我可以期望以这种方式访问​​输入

如何合并数据?

我可以将功能块留空吗? def merge(…) = {}

evaluate()中对缓冲区进行排序的挑战是......我应该能够弄清楚,虽然我仍然对你们如何做到这一点的最优雅的方式感兴趣(在很短的时间内)。

奖金问题:dataType扮演什么角色?

我返回一个案例类,而不是StructType中定义的dataType,这似乎不是问题。或者它是否有效,因为它碰巧匹配我的案例类?

1 个答案:

答案 0 :(得分:2)

也许这会让事情变得清晰。

UDAF APIs工作DataFrame Columns。您正在做的所有操作都必须像Columns中的所有其他DataFrame一样进行序列化。如您所知,唯一的支持MapType是不可变的,因为这是您可以放在Column中的唯一内容。使用不可变集合,您只需创建一个包含旧集合和值的新集合:

var map = Map[Long,Long]()
map = map + (0L -> 1234L)
map = map + (1L -> 4567L)

是的,就像使用任何DataFrame一样,您的类型必须匹配。 buffer.getInt(0)LongType确实存在问题时,请null

没有标准方法可以重置缓冲区,因为除了对数据类型/用例有意义的其他方法。也许零实际上是上个月的balanace;也许零是另一个数据集的运行平均值;也许零是merge或空字符串,或者可能为零实际上为零。

update是一种优化,只有在某些情况下才会发生,如果我没记错的话 - 如果环境允许,可以使用SQL优化可以使用的一种方法。我只使用我用于case class的相同功能。

dataType将自动转换为适当的架构,因此对于奖励问题,答案是肯定的,因为架构匹配。将mybatis-mapper.xml更改为不匹配,您将收到错误。