在多个列上应用自定义Spark Aggregator(Spark 2.0)

时间:2017-01-05 11:22:34

标签: apache-spark apache-spark-sql aggregate-functions user-defined-functions

我为Strings创建了一个自定义Aggregator[]

我想将它应用于DataFrame的所有列,其中所有列都是字符串,但列号是任意的。

我坚持写正确的表达方式。我想写这样的东西:

df.agg( df.columns.map( c => myagg(df(c)) ) : _*) 

鉴于各种接口,这显然是错误的。

我查看了RelationalGroupedDataset.agg(expr: Column, exprs: Column*)代码,但我不熟悉表达式操作。

有什么想法吗?

1 个答案:

答案 0 :(得分:7)

与对单个字段(列)进行操作的UserDefinedAggregateFunctions相比,Aggregtors需要完整的Row /值。

如果您想要和Aggregator可以在代码段中使用,则必须按列名称对其进行参数化,并使用Row作为值类型。

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, Row}

case class Max(col: String) 
    extends Aggregator[Row, Int, Int] with Serializable {

  def zero = Int.MinValue
  def reduce(acc: Int, x: Row) =
    Math.max(acc, Option(x.getAs[Int](col)).getOrElse(zero))

  def merge(acc1: Int, acc2: Int) = Math.max(acc1, acc2)
  def finish(acc: Int) = acc

  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

使用示例:

val df = Seq((1, None, 3), (4, Some(5), -6)).toDF("x", "y", "z")

@transient val exprs = df.columns.map(c => Max(c).toColumn.alias(s"max($c)"))

df.agg(exprs.head, exprs.tail: _*)
+------+------+------+
|max(x)|max(y)|max(z)|
+------+------+------+
|     4|     5|     3|
+------+------+------+

与静态类型Aggregators相比,DatasetsDataset<Row>相比,Seq[_]更有意义。

根据您的要求,您还可以使用Row累加器在一次传递中聚合多个列,并在单个merge调用中处理整个fuse_context(记录)。