Spark custom aggregation: collect_list + UDF vs. UDAF

Date: 2018-03-15 08:08:35

Tags: apache-spark dataframe aggregate-functions user-defined-functions

I regularly need to perform custom aggregations on DataFrames in Spark 2.1, and I use one of these two approaches:

  • use groupBy / collect_list to gather all values into a single row, then apply a UDF to aggregate the values
  • write a custom UDAF (user-defined aggregate function)

I usually prefer the first option because it is easier to implement and more readable than a UDAF. I would expect it to be the slower one, since more data is sent over the network (there is no partial aggregation), but in my experience the UDAF is usually the slow one. Why is that?

Concrete example: computing a histogram

The data is in a Hive table (1e6 random double values).

val df = spark.table("testtable")

def roundToMultiple(d: Double, multiple: Double) = Math.round(d / multiple) * multiple
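As a quick sanity check (plain Scala, no Spark required), the helper above snaps each value to the nearest multiple of the bin width, so with a width of 0.25 every value lands on a quarter boundary:

```scala
// Same binning helper as above: divide, round to the nearest integer,
// then scale back up to the bin boundary.
def roundToMultiple(d: Double, multiple: Double): Double =
  Math.round(d / multiple) * multiple

// 0.37 / 0.25 = 1.48, which rounds to 1, giving bin 0.25.
println(roundToMultiple(0.37, 0.25))
// 0.40 / 0.25 = 1.6, which rounds to 2, giving bin 0.5.
println(roundToMultiple(0.40, 0.25))
```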

UDF approach:

val udf_histo = udf((xs: Seq[Double]) => xs.groupBy(x => roundToMultiple(x, 0.25)).mapValues(_.size))

df.groupBy().agg(collect_list($"x").as("xs")).select(udf_histo($"xs")).show(false)

+--------------------------------------------------------------------------------+
|UDF(xs)                                                                         |
+--------------------------------------------------------------------------------+
|Map(0.0 -> 125122, 1.0 -> 124772, 0.75 -> 250819, 0.5 -> 248696, 0.25 -> 250591)|
+--------------------------------------------------------------------------------+

UDAF approach:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.collection.mutable

class HistoUDAF(binWidth: Double) extends UserDefinedAggregateFunction {

  override def inputSchema: StructType =
    StructType(
      StructField("value", DoubleType) :: Nil
    )

  override def bufferSchema: StructType =
    new StructType()
      .add("histo", MapType(DoubleType, IntegerType))

  override def deterministic: Boolean = true
  override def dataType: DataType = MapType(DoubleType, IntegerType)
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[Double, Int]()
  }

  private def mergeMaps(a: Map[Double, Int], b: Map[Double, Int]) = {
    a ++ b.map { case (k, v) => k -> (v + a.getOrElse(k, 0)) }
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val oldBuffer = buffer.getAs[Map[Double, Int]](0)
    val newInput = Map(roundToMultiple(input.getDouble(0), binWidth) -> 1)
    buffer(0) = mergeMaps(oldBuffer, newInput)
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val a = buffer1.getAs[Map[Double, Int]](0)
    val b = buffer2.getAs[Map[Double, Int]](0)
    buffer1(0) = mergeMaps(a, b)
  }

  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Map[Double, Int]](0)
  }
}
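The mergeMaps helper used by both update and merge sums the counts of keys that appear in both maps and keeps the rest unchanged. The same logic can be exercised standalone (plain Scala, no Spark required):

```scala
// Standalone version of the mergeMaps helper from the UDAF above:
// keys present in both maps have their counts added together.
def mergeMaps(a: Map[Double, Int], b: Map[Double, Int]): Map[Double, Int] =
  a ++ b.map { case (k, v) => k -> (v + a.getOrElse(k, 0)) }

val left  = Map(0.0 -> 2, 0.25 -> 1)
val right = Map(0.25 -> 3, 0.5 -> 4)

// The shared key 0.25 is summed (1 + 3 = 4); 0.0 and 0.5 pass through.
val merged = mergeMaps(left, right)
println(merged)
```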

val histo = new HistoUDAF(0.25)

df.groupBy().agg(histo($"x")).show(false)

+--------------------------------------------------------------------------------+
|histoudaf(x)                                                                    |
+--------------------------------------------------------------------------------+
|Map(0.0 -> 125122, 1.0 -> 124772, 0.75 -> 250819, 0.5 -> 248696, 0.25 -> 250591)|
+--------------------------------------------------------------------------------+

My tests show that the collect_list / UDF approach is about twice as fast as the UDAF approach. Is that a general rule, or are there cases where the UDAF really is much faster, so that its rather clumsy implementation is justified?

0 Answers:

No answers yet