Speeding up groupBy on a Spark DataFrame

Time: 2018-11-06 22:44:37

Tags: scala apache-spark group-by aggregate aggregate-functions

I am inexperienced with Spark, so I need help with groupBy and with applying aggregate functions on a DataFrame. Consider the following DataFrame:

val df = (Seq((1, "a", "1"),
              (1,"b", "3"),
              (1,"c", "6"),
              (2, "a", "9"),
              (2,"c", "10"),
              (1,"b","8" ),
              (2, "c", "3"),
              (3,"r", "19")).toDF("col1", "col2", "col3"))

df.show()
+----+----+----+
|col1|col2|col3|
+----+----+----+
|   1|   a|   1|
|   1|   b|   3|
|   1|   c|   6|
|   2|   a|   9|
|   2|   c|  10|
|   1|   b|   8|
|   2|   c|   3|
|   3|   r|  19|
+----+----+----+

I need to group by col1 and by col2 and compute the mean of col3 for each, which I can do with:

val col1df = df.groupBy("col1").agg(round(mean("col3"),2).alias("mean_col1"))
val col2df = df.groupBy("col2").agg(round(mean("col3"),2).alias("mean_col2"))

However, on a large DataFrame with a few million rows and tens of thousands of unique elements in the columns being grouped, this takes a very long time. Moreover, I have many more columns to group by, and the total runtime grows with each one, which I would like to reduce. Is there a better way to do the groupBy and then the aggregation?
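
For reference, with many grouping columns this approach generalizes to something like the sketch below (groupCols is a hypothetical list of column names), where every groupBy triggers its own shuffle:

import org.apache.spark.sql.functions.{mean, round}

// One groupBy per column; each call causes a separate shuffle of the full DataFrame.
val groupCols = Seq("col1", "col2")  // hypothetical: in practice this list is much longer
val perColumnMeans = groupCols.map { c =>
  df.groupBy(c).agg(round(mean("col3"), 2).alias(s"mean_$c"))
}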

1 Answer:

Answer 0: (score: 2)

You can use the idea from Multiple Aggregations, which does everything in a single shuffle, the most expensive operation.

Example:

val df = (Seq((1, "a", "1"),
              (1, "b", "3"),
              (1, "c", "6"),
              (2, "a", "9"),
              (2, "c", "10"),
              (1, "b", "8"),
              (2, "c", "3"),
              (3, "r", "19")).toDF("col1", "col2", "col3"))

df.createOrReplaceTempView("data")

val grpRes = spark.sql("""select grouping_id() as gid, col1, col2, round(mean(col3), 2) as res 
                          from data group by col1, col2 grouping sets ((col1), (col2)) """)

grpRes.show(100, false)

Output:

+---+----+----+----+
|gid|col1|col2|res |
+---+----+----+----+
|1  |3   |null|19.0|
|2  |null|b   |5.5 |
|2  |null|c   |6.33|
|1  |1   |null|4.5 |
|2  |null|a   |5.0 |
|1  |2   |null|7.33|
|2  |null|r   |19.0|
+---+----+----+----+

The gid is a bit funny to use, as there is some binary arithmetic underneath it. But if your grouping columns cannot be null, you can use it to select the correct grouping.
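
For example (a small sketch based on the output above), you can split the combined result back into the two per-column aggregates by filtering on gid:

// gid = 1 corresponds to the (col1) grouping set, gid = 2 to (col2);
// the values follow from grouping_id()'s bit encoding over the two grouping columns.
val col1Means = grpRes.where("gid = 1").select("col1", "res").withColumnRenamed("res", "mean_col1")
val col2Means = grpRes.where("gid = 2").select("col2", "res").withColumnRenamed("res", "mean_col2")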

Execution plan:

scala> grpRes.explain
== Physical Plan ==
*(2) HashAggregate(keys=[col1#111, col2#112, spark_grouping_id#108], functions=[avg(cast(col3#9 as double))])
+- Exchange hashpartitioning(col1#111, col2#112, spark_grouping_id#108, 200)
   +- *(1) HashAggregate(keys=[col1#111, col2#112, spark_grouping_id#108], functions=[partial_avg(cast(col3#9 as double))])
      +- *(1) Expand [List(col3#9, col1#109, null, 1), List(col3#9, null, col2#110, 2)], [col3#9, col1#111, col2#112, spark_grouping_id#108]
         +- LocalTableScan [col3#9, col1#109, col2#110]

You can see there is only a single Exchange operation, the expensive shuffle.
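
If you would rather stay in the DataFrame API, a rough alternative (not exactly the query above, so treat it as a sketch) is cube, which also aggregates in one shuffle pass but computes extra grouping sets, namely (col1, col2) and the grand total, which you then filter out:

import org.apache.spark.sql.functions.{grouping_id, mean, round}

// cube produces every grouping set over (col1, col2); keep only the single-column ones.
val cubeRes = df.cube("col1", "col2")
  .agg(grouping_id().alias("gid"), round(mean("col3"), 2).alias("res"))
  .where("gid = 1 or gid = 2")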