我不理解为非平凡聚合器确定mergeExpressions函数所采用的一般方法。 像org.apache.spark.sql.catalyst.expressions.aggregate.Average这样的mergeExpresssions方法非常简单:
override lazy val mergeExpressions = Seq(
/* sum = */ sum.left + sum.right,
/* count = */ count.left + count.right
)
CentralMomentAgg聚合器的mergeExpressions涉及更多。 我想要做的是创建一个在Sparks CentralMomentAgg之后建模的WeightedStddevSamp聚合器。 我几乎让它工作,但它产生的加权标准偏差仍然与我手工计算的有点偏差。 我无法调试它,因为我不明白如何计算mergeExpressions方法的确切逻辑。 以下是我的代码。 updateExpressions方法基于此weighted incremental algorithm,因此我非常确定该方法是正确的。我相信我的问题在mergeExpressions方法中。任何提示都将不胜感激。
abstract class WeightedCentralMomentAgg(child: Expression, weight: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = Seq(child, weight)
override def nullable: Boolean = true
override def dataType: DataType = DoubleType
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
protected val wSum = AttributeReference("wSum", DoubleType, nullable = false)()
protected val mean = AttributeReference("mean", DoubleType, nullable = false)()
protected val s = AttributeReference("s", DoubleType, nullable = false)()
override val aggBufferAttributes = Seq(wSum, mean, s)
override val initialValues: Seq[Expression] = Array.fill(3)(Literal(0.0))
// See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm
override val updateExpressions: Seq[Expression] = {
val newWSum = wSum + weight
val newMean = mean + (weight / newWSum) * (child - mean)
val newS = s + weight * (child - mean) * (child - newMean)
Seq(
If(IsNull(child), wSum, newWSum),
If(IsNull(child), mean, newMean),
If(IsNull(child), s, newS)
)
}
override val mergeExpressions: Seq[Expression] = {
val wSum1 = wSum.left
val wSum2 = wSum.right
val newWSum = wSum1 + wSum2
val delta = mean.right - mean.left
val deltaN = If(newWSum === Literal(0.0), Literal(0.0), delta / newWSum)
val newMean = mean.left + wSum1 / newWSum * delta // ???
val newS = s.left + s.right + wSum1 * wSum2 * delta * deltaN // ???
Seq(newWSum, newMean, newS)
}
}
// Compute the weighted sample standard deviation of a column
case class WeightedStddevSamp(child: Expression, weight: Expression)
extends WeightedCentralMomentAgg(child, weight) {
override val evaluateExpression: Expression = {
If(wSum === Literal(0.0), Literal.create(null, DoubleType),
If(wSum === Literal(1.0), Literal(Double.NaN),
Sqrt(s / wSum) ) )
}
override def prettyName: String = "wtd_stddev_samp"
}
答案 0 :(得分:2)
对于任何哈希聚合,它分为四个步骤:
1)初始化缓冲区(wSum,mean,s)
2)在分区内,在给定所有输入的情况下更新密钥的缓冲区(为每个输入调用updateExpression)
3)在混洗之后,使用mergeExpression将所有缓冲区合并为相同的密钥。 wSum.left表示左缓冲区中的wSum,wSum.right表示另一个缓冲区中的wSum
4)使用valueExpression
从缓冲区中获取最终结果答案 1 :(得分:0)
我发现如何为加权标准差编写mergeExpressions函数。我实际上是正确的,但后来在evaluateExpression中使用了总体方差而不是样本方差计算。下面显示的实现给出了与上面相同的结果,但更容易理解。
override val mergeExpressions: Seq[Expression] = {
val newN = n.left + n.right
val wSum1 = wSum.left
val wSum2 = wSum.right
val newWSum = wSum1 + wSum2
val delta = mean.right - mean.left
val deltaN = If(newWSum === Literal(0.0), Literal(0.0), delta / newWSum)
val newMean = mean.left + deltaN * wSum2
val newS = (((wSum1 * s.left) + (wSum2 * s.right)) / newWSum) + (wSum1 * wSum2 * deltaN * deltaN)
Seq(newN, newWSum, newMean, newS)
}
以下是一些参考资料
Davies的帖子概述了该方法,但对于许多非平凡的聚合器,我认为mergeExpressions函数可能非常复杂,并且涉及高级数学以确定正确有效的解决方案。幸运的是,在这种情况下,我找到了一个曾经解决过的人。
此解决方案与我手工完成的工作相匹配。重要的是要注意,如果你想要样本方差而不是总体方差,需要稍微修改evaluateExpression(为s /((n-1)* wSum / n))。