我使用的是Spark 2.2。我有一个关于使用ArrayType
的基本问题。我没有找到可以使用的内置聚合函数。
给定DataFrame
列id
,列values
ArrayType
。
我们希望按ID分组,然后按索引计算平均值。
所以给出以下输入
{"id": 1, "values":[1.0, 3.0]}
{"id": 1, "values":[3.0, 7.0]}
{"id": 2, "values":[2.0, 4.0]}
我们想要这个输出
{"id": 1, "values":[2.0, 5.0]}
{"id": 2, "values":[2.0, 4.0]}
我使用 UDAF 提出解决方案,请参阅下面的代码。
在表现方面是否有更好的方法(例如不使用UDAF)?
val meanByIndex = new UserDefinedAggregateFunction {
override def inputSchema: StructType =
StructType(
StructField("values", ArrayType(DoubleType)) :: Nil
)
override def dataType: DataType = ArrayType(DoubleType)
override def deterministic: Boolean = true
override def update(buffer: MutableAggregationBuffer, row: Row): Unit = {
buffer.update(0, buffer.getAs[Long](0) + 1)
buffer.update(1, sumSeq(buffer.getAs[Seq[Double]](1), row.getAs[Seq[Double]](0))
)
}
override def bufferSchema: StructType =
StructType(
StructField("size", LongType) ::
StructField("sum", ArrayType(DoubleType)) :: Nil
)
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.getAs[Long](0) + buffer2.getAs[Long](0))
buffer1.update(1, sumSeq(buffer1.getAs[Seq[Double]](1), buffer2.getAs[Seq[Double]](1))
)
}
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, 0L)
buffer.update(1, Seq.empty[Double])
}
override def evaluate(buffer: Row): Any = {
buffer.getAs[Seq[Double]](1).map(_ / buffer.getAs[Long](0))
}
private def sumSeq(s1: Seq[Double], s2: Seq[Double]) = {
if (s1.isEmpty)
s2
else {
s1.zip(s2).map { case (v1, v2) => v1 + v2 }
}
}
}
[更新]关于@ user6910411的答案,我已经比较了执行计划。
使用UDAF
SortAggregate
+- *Sort [id#1 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(id#1, 200)
+- SortAggregate
+- *Sort [id#1 ASC NULLS FIRST], false, 0
+- *FileScan json
没有UDAF
*HashAggregate
+- Exchange hashpartitioning(id#1, 200)
+- *HashAggregate
+- *FileScan json
结论:没有UDAF的解决方案更好,因为我们不需要对整个数据集进行排序。
答案 0 :(得分:0)
对于固定大小的数组,我不会为UserDefinedAggregateFunction
而烦恼并使用标准聚合:
import org.apache.spark.sql.functions._
val df = Seq(
(1, Seq(1.0, 3.0)),
(1, Seq(3.0, 7.0)),
(2, Seq(2.0, 4.0))
).toDF("id", "values")
df.groupBy("id").agg(array((0 until 2) map (i => avg($"values"(i))): _*))
+---+-------------------------------------+
| id|array(avg(values[0]), avg(values[1]))|
+---+-------------------------------------+
| 1| [2.0, 5.0]|
| 2| [2.0, 4.0]|
+---+-------------------------------------+