Pyspark-带过滤器的分组方式-优化速度

时间:2019-11-06 07:59:51

标签: python pyspark

我有数十亿行要使用Pyspark处理。

数据框如下所示:

category    value    flag
   A          10       1
   A          12       0
   B          15       0
and so on...

我需要运行两个groupby操作:一个在flag == 1的行上运行,另一个在ALL行上运行。目前,我正在这样做:

frame_1 = df.filter(df.flag==1).groupBy('category').agg(F.sum('value').alias('foo1'))
frame_2 = df.groupBy('category').agg(F.sum('value').alias(foo2))
final_frame = frame1.join(frame2,on='category',how='left')

到目前为止,此代码正在运行,但是我的问题是运行速度很慢。 有没有一种方法可以改进此代码的速度,或者这是极限,因为我了解PySpark的惰性评估确实需要一些时间,但是此代码是实现此目的的最佳方法吗?

2 个答案:

答案 0 :(得分:1)

IIUC,您可以避免昂贵的联接并使用一个groupBy来实现。

final_frame_2 = df.groupBy("category").agg(
    F.sum(F.col("value")*F.col("flag")).alias("foo1"),
    F.sum(F.col("value")).alias("foo2"),
)
final_frame_2.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#|       B| 0.0|15.0|
#|       A|10.0|22.0|
#+--------+----+----+

现在比较执行计划:

首先您的方法:

final_frame.explain()
#== Physical Plan ==
#*(5) Project [category#0, foo1#68, foo2#75]
#+- SortMergeJoin [category#0], [category#78], LeftOuter
#   :- *(2) Sort [category#0 ASC NULLS FIRST], false, 0
#   :  +- *(2) HashAggregate(keys=[category#0], functions=[sum(cast(value#1 as double))])
#   :     +- Exchange hashpartitioning(category#0, 200)
#   :        +- *(1) HashAggregate(keys=[category#0], functions=[partial_sum(cast(value#1 as double))])
#   :           +- *(1) Project [category#0, value#1]
#   :              +- *(1) Filter (isnotnull(flag#2) && (cast(flag#2 as int) = 1))
#   :                 +- Scan ExistingRDD[category#0,value#1,flag#2]
#   +- *(4) Sort [category#78 ASC NULLS FIRST], false, 0
#      +- *(4) HashAggregate(keys=[category#78], functions=[sum(cast(value#79 as double))])
#         +- Exchange hashpartitioning(category#78, 200)
#            +- *(3) HashAggregate(keys=[category#78], functions=[partial_sum(cast(value#79 as double))])
#               +- *(3) Project [category#78, value#79]
#                  +- Scan ExistingRDD[category#78,value#79,flag#80]

final_frame_2现在相同:

final_frame_2.explain()
#== Physical Plan ==
#*(2) HashAggregate(keys=[category#0], functions=[sum((cast(value#1 as double) * cast(flag#2 as double))), sum(cast(value#1 as double))])
#+- Exchange hashpartitioning(category#0, 200)
#   +- *(1) HashAggregate(keys=[category#0], functions=[partial_sum((cast(value#1 as double) * cast(flag#2 as double))), partial_sum(cast(value#1 as double))])
#      +- Scan ExistingRDD[category#0,value#1,flag#2]

注意:严格来说,这与您给出的示例输出的完全不完全相同(如下所示),因为您的内部联接将消除所有没有类别的类别flag = 1行。

final_frame.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#|       A|10.0|22.0|
#+--------+----+----+

您可以将聚合添加到总和flag上,并过滤总和为零的那些(如果需要),而对性能的影响很小。

final_frame_3 = df.groupBy("category").agg(
    F.sum(F.col("value")*F.col("flag")).alias("foo1"),
    F.sum(F.col("value")).alias("foo2"),
    F.sum(F.col("flag")).alias("foo3")
).where(F.col("foo3")!=0).drop("foo3")

final_frame_3.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#|       A|10.0|22.0|
#+--------+----+----+

答案 1 :(得分:0)

请注意,联接操作很昂贵。 您可以执行此操作并将标志添加到您的组:

frame_1 = df.groupBy(["category", "flag"]).agg(F.sum('value').alias('foo1'))

如果您有两个以上的标志,并且想执行flag == 1 vs the rest,则:

import pyspark.sql.functions as F
frame_1 = df.withColumn("flag2", F.when(F.col("flag") == 1, 1).otherwise(0))
frame_1 = df.groupBy(["category", "flag2"]).agg(F.sum('value').alias('foo1'))

如果您要对所有行进行groupby申请,只需创建一个新框架,然后在该框架中再次汇总类别:

frame_1 = df.groupBy("category").agg(F.sum('foo1').alias('foo2'))

不可能同时完成这两个步骤,因为实际上存在一个组重叠。