如何通过数据框中数组列的索引计算平均值

时间:2017-10-01 13:47:20

标签: scala apache-spark dataframe apache-spark-sql

我使用的是Spark 2.2。我有一个关于使用ArrayType的基本问题。我没有找到可以使用的内置聚合函数。

给定DataFrameid,列values ArrayType

我们希望按ID分组,然后按索引计算平均值。

所以给出以下输入

{"id": 1, "values":[1.0, 3.0]}
{"id": 1, "values":[3.0, 7.0]}
{"id": 2, "values":[2.0, 4.0]}

我们想要这个输出

{"id": 1, "values":[2.0, 5.0]}
{"id": 2, "values":[2.0, 4.0]}

我使用 UDAF 提出解决方案,请参阅下面的代码。

在表现方面是否有更好的方法(例如不使用UDAF)?

  val meanByIndex = new UserDefinedAggregateFunction {
    override def inputSchema: StructType =
      StructType(
        StructField("values", ArrayType(DoubleType)) :: Nil
      )

    override def dataType: DataType = ArrayType(DoubleType)

    override def deterministic: Boolean = true

    override def update(buffer: MutableAggregationBuffer, row: Row): Unit = {
      buffer.update(0, buffer.getAs[Long](0) + 1)
      buffer.update(1, sumSeq(buffer.getAs[Seq[Double]](1), row.getAs[Seq[Double]](0))
      )
    }

    override def bufferSchema: StructType =
      StructType(
        StructField("size", LongType) ::
          StructField("sum", ArrayType(DoubleType)) :: Nil
      )

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getAs[Long](0) + buffer2.getAs[Long](0))
      buffer1.update(1, sumSeq(buffer1.getAs[Seq[Double]](1), buffer2.getAs[Seq[Double]](1))
      )
    }

    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer.update(0, 0L)
      buffer.update(1, Seq.empty[Double])
    }

    override def evaluate(buffer: Row): Any = {
      buffer.getAs[Seq[Double]](1).map(_ / buffer.getAs[Long](0))
    }

    private def sumSeq(s1: Seq[Double], s2: Seq[Double]) = {
      if (s1.isEmpty)
        s2
      else {
        s1.zip(s2).map { case (v1, v2) => v1 + v2 }
      }
    }
  }

[更新]关于@ user6910411的答案,我已经比较了执行计划。

使用UDAF

SortAggregate
+- *Sort [id#1 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(id#1, 200)
      +- SortAggregate
         +- *Sort [id#1 ASC NULLS FIRST], false, 0
            +- *FileScan json

没有UDAF

*HashAggregate
+- Exchange hashpartitioning(id#1, 200)
   +- *HashAggregate
      +- *FileScan json

结论:没有UDAF的解决方案更好,因为我们不需要对整个数据集进行排序。

1 个答案:

答案 0 :(得分:0)

对于固定大小的数组,我不会为UserDefinedAggregateFunction而烦恼并使用标准聚合:

import org.apache.spark.sql.functions._

val df = Seq(
  (1, Seq(1.0, 3.0)),
  (1, Seq(3.0, 7.0)),
  (2, Seq(2.0, 4.0))
).toDF("id", "values")

df.groupBy("id").agg(array((0 until 2) map (i => avg($"values"(i))): _*))
+---+-------------------------------------+
| id|array(avg(values[0]), avg(values[1]))|
+---+-------------------------------------+
|  1|                           [2.0, 5.0]|
|  2|                           [2.0, 4.0]|
+---+-------------------------------------+