如何为自定义DeclarativeAggregate定义mergeExpressions(在催化剂包中)

时间:2017-03-09 19:21:53

标签: apache-spark

我不理解为非平凡聚合器确定mergeExpressions函数所采用的一般方法。 像org.apache.spark.sql.catalyst.expressions.aggregate.Average这样的mergeExpresssions方法非常简单:

override lazy val mergeExpressions = Seq(
    /* sum = */ sum.left + sum.right,
    /* count = */ count.left + count.right
  )

CentralMomentAgg聚合器的mergeExpressions涉及更多。 我想要做的是创建一个在Sparks CentralMomentAgg之后建模的WeightedStddevSamp聚合器。 我几乎让它工作,但它产生的加权标准偏差仍然与我手工计算的有点偏差。 我无法调试它,因为我不明白如何计算mergeExpressions方法的确切逻辑。 以下是我的代码。 updateExpressions方法基于此weighted incremental algorithm,因此我非常确定该方法是正确的。我相信我的问题在mergeExpressions方法中。任何提示都将不胜感激。

abstract class WeightedCentralMomentAgg(child: Expression, weight: Expression) extends DeclarativeAggregate {

  override def children: Seq[Expression] = Seq(child, weight)
  override def nullable: Boolean = true
  override def dataType: DataType = DoubleType
  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)

  protected val wSum = AttributeReference("wSum", DoubleType, nullable = false)()
  protected val mean = AttributeReference("mean", DoubleType, nullable = false)()
  protected val s = AttributeReference("s", DoubleType, nullable = false)()
  override val aggBufferAttributes = Seq(wSum, mean, s)
  override val initialValues: Seq[Expression] = Array.fill(3)(Literal(0.0))

  // See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm
  override val updateExpressions: Seq[Expression] = {

    val newWSum = wSum + weight
    val newMean = mean + (weight / newWSum) * (child - mean)
    val newS = s + weight * (child - mean) * (child - newMean)

    Seq(
      If(IsNull(child), wSum, newWSum),
      If(IsNull(child), mean, newMean),
      If(IsNull(child), s, newS)
    )
  }

  override val mergeExpressions: Seq[Expression] = {
    val wSum1 = wSum.left
    val wSum2 = wSum.right
    val newWSum = wSum1 + wSum2
    val delta = mean.right - mean.left
    val deltaN = If(newWSum === Literal(0.0), Literal(0.0), delta / newWSum)
    val newMean = mean.left + wSum1 / newWSum * delta                //  ???
    val newS = s.left + s.right + wSum1 * wSum2 * delta * deltaN     //  ???
    Seq(newWSum, newMean, newS)
  }
}


// Compute the weighted sample standard deviation of a column
case class WeightedStddevSamp(child: Expression, weight: Expression)
  extends WeightedCentralMomentAgg(child, weight) {

  override val evaluateExpression: Expression = {
    If(wSum === Literal(0.0), Literal.create(null, DoubleType),
      If(wSum === Literal(1.0), Literal(Double.NaN),
        Sqrt(s / wSum) ) )
  }

  override def prettyName: String = "wtd_stddev_samp"
}

2 个答案:

答案 0 :(得分:2)

对于任何哈希聚合,它分为四个步骤:

1)初始化缓冲区(wSum,mean,s)

2)在分区内,在给定所有输入的情况下更新密钥的缓冲区(为每个输入调用updateExpression)

3)在混洗之后,使用mergeExpression将所有缓冲区合并为相同的密钥。 wSum.left表示左缓冲区中的wSum,wSum.right表示另一个缓冲区中的wSum

4)使用valueExpression

从缓冲区中获取最终结果

答案 1 :(得分:0)

我发现如何为加权标准差编写mergeExpressions函数。我实际上是正确的,但后来在evaluateExpression中使用了总体方差而不是样本方差计算。下面显示的实现给出了与上面相同的结果,但更容易理解。

override val mergeExpressions: Seq[Expression] = {   
    val newN = n.left + n.right
    val wSum1 = wSum.left
    val wSum2 = wSum.right
    val newWSum = wSum1 + wSum2
    val delta = mean.right - mean.left

    val deltaN = If(newWSum === Literal(0.0), Literal(0.0), delta / newWSum)
    val newMean = mean.left + deltaN * wSum2
    val newS =  (((wSum1 * s.left) + (wSum2 * s.right)) / newWSum) + (wSum1 * wSum2 * deltaN * deltaN)

    Seq(newN, newWSum, newMean, newS)
}

以下是一些参考资料

Davies的帖子概述了该方法,但对于许多非平凡的聚合器,我认为mergeExpressions函数可能非常复杂,并且涉及高级数学以确定正确有效的解决方案。幸运的是,在这种情况下,我找到了一个曾经解决过的人。

此解决方案与我手工完成的工作相匹配。重要的是要注意,如果你想要样本方差而不是总体方差,需要稍微修改evaluateExpression(为s /((n-1)* wSum / n))。