我正在尝试开发用户定义的聚合函数,该函数计算一行数字的线性回归。我已经成功地完成了一个计算均值置信区间的UDAF(有很多试验和错误以及SO!)。
以下是我实际运行的内容:
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{StructType, StructField, DoubleType, LongType, DataType, ArrayType}
case class RegressionData(intercept: Double, slope: Double)
class Regression {
import org.apache.commons.math3.stat.regression.SimpleRegression
def roundAt(p: Int)(n: Double): Double = { val s = math pow (10, p); (math round n * s) / s }
def getRegression(data: List[Long]): RegressionData = {
val regression: SimpleRegression = new SimpleRegression()
data.view.zipWithIndex.foreach { d =>
regression.addData(d._2.toDouble, d._1.toDouble)
}
RegressionData(roundAt(3)(regression.getIntercept()), roundAt(3)(regression.getSlope()))
}
}
class UDAFRegression extends UserDefinedAggregateFunction {
import java.util.ArrayList
def deterministic = true
def inputSchema: StructType =
new StructType().add("units", LongType)
def bufferSchema: StructType =
new StructType().add("buff", ArrayType(LongType))
def dataType: DataType =
new StructType()
.add("intercept", DoubleType)
.add("slope", DoubleType)
def initialize(buffer: MutableAggregationBuffer) = {
buffer.update(0, new ArrayList[Long]())
}
def update(buffer: MutableAggregationBuffer, input: Row) = {
val longList: ArrayList[Long] = new ArrayList[Long](buffer.getList(0))
longList.add(input.getLong(0));
buffer.update(0, longList);
}
def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
val longList: ArrayList[Long] = new ArrayList[Long](buffer1.getList(0))
longList.addAll(buffer2.getList(0))
buffer1.update(0, longList)
}
def evaluate(buffer: Row) = {
import scala.collection.JavaConverters._
val list = buffer.getList(0).asScala.toList
val regression = new Regression
regression.getRegression(list)
}
}
然而,数据集不是按顺序排列的,这在这里显然非常重要。因此,我需要第二个参数regression($"longValue")
而不是regression($"longValue", $"created_day")
。 created_day
是sql.types.DateType
。
我对DataTypes,StructTypes和what-not感到困惑,由于网络上缺少示例,我在这里试用了我的试用和订单。
bufferSchema
会是什么样的?在我的情况下,这些StructTypes是否有开销?一个(可变的)Map
不会这么做吗? MapType
实际上是不可变的,并且这不是一个毫无意义的缓冲类型吗?
inputSchema
会是什么样的?这是否必须与我在案例update()
input.getLong(0)
中检索的类型相匹配?
initialize()
我看过buffer.update(0, 0.0)
(当它包含双打时,很明显),buffer(0) = new WhatEver()
我甚至认为buffer = Nil
。这些都有所不同吗?
上面的例子似乎过于复杂。我原以为能够做某事。比如buffer += input.getLong(0) -> input.getDate(1)
。
我可以期望以这种方式访问输入
我可以将功能块留空吗?
def merge(…) = {}
?
在evaluate()
中对缓冲区进行排序的挑战是......我应该能够弄清楚,虽然我仍然对你们如何做到这一点的最优雅的方式感兴趣(在很短的时间内)。
dataType
扮演什么角色?我返回一个案例类,而不是StructType
中定义的dataType
,这似乎不是问题。或者它是否有效,因为它碰巧匹配我的案例类?
答案 0 :(得分:2)
也许这会让事情变得清晰。
UDAF APIs
工作DataFrame
Columns
。您正在做的所有操作都必须像Columns
中的所有其他DataFrame
一样进行序列化。如您所知,唯一的支持MapType
是不可变的,因为这是您可以放在Column
中的唯一内容。使用不可变集合,您只需创建一个包含旧集合和值的新集合:
var map = Map[Long,Long]()
map = map + (0L -> 1234L)
map = map + (1L -> 4567L)
是的,就像使用任何DataFrame
一样,您的类型必须匹配。 buffer.getInt(0)
当LongType
确实存在问题时,请null
。
没有标准方法可以重置缓冲区,因为除了对数据类型/用例有意义的其他方法。也许零实际上是上个月的balanace;也许零是另一个数据集的运行平均值;也许零是merge
或空字符串,或者可能为零实际上为零。
update
是一种优化,只有在某些情况下才会发生,如果我没记错的话 - 如果环境允许,可以使用SQL优化可以使用的一种方法。我只使用我用于case class
的相同功能。
dataType
将自动转换为适当的架构,因此对于奖励问题,答案是肯定的,因为架构匹配。将mybatis-mapper.xml
更改为不匹配,您将收到错误。