我有数十亿行要使用Pyspark处理。
数据框如下所示:
category value flag
A 10 1
A 12 0
B 15 0
and so on...
我需要运行两个groupby操作:一个在flag == 1的行上运行,另一个在ALL行上运行。目前,我正在这样做:
frame_1 = df.filter(df.flag==1).groupBy('category').agg(F.sum('value').alias('foo1'))
frame_2 = df.groupBy('category').agg(F.sum('value').alias(foo2))
final_frame = frame1.join(frame2,on='category',how='left')
到目前为止,此代码正在运行,但是我的问题是运行速度很慢。 有没有一种方法可以改进此代码的速度,或者这是极限,因为我了解PySpark的惰性评估确实需要一些时间,但是此代码是实现此目的的最佳方法吗?
答案 0 :(得分:1)
IIUC,您可以避免昂贵的联接并使用一个groupBy
来实现。
final_frame_2 = df.groupBy("category").agg(
F.sum(F.col("value")*F.col("flag")).alias("foo1"),
F.sum(F.col("value")).alias("foo2"),
)
final_frame_2.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#| B| 0.0|15.0|
#| A|10.0|22.0|
#+--------+----+----+
现在比较执行计划:
首先您的方法:
final_frame.explain()
#== Physical Plan ==
#*(5) Project [category#0, foo1#68, foo2#75]
#+- SortMergeJoin [category#0], [category#78], LeftOuter
# :- *(2) Sort [category#0 ASC NULLS FIRST], false, 0
# : +- *(2) HashAggregate(keys=[category#0], functions=[sum(cast(value#1 as double))])
# : +- Exchange hashpartitioning(category#0, 200)
# : +- *(1) HashAggregate(keys=[category#0], functions=[partial_sum(cast(value#1 as double))])
# : +- *(1) Project [category#0, value#1]
# : +- *(1) Filter (isnotnull(flag#2) && (cast(flag#2 as int) = 1))
# : +- Scan ExistingRDD[category#0,value#1,flag#2]
# +- *(4) Sort [category#78 ASC NULLS FIRST], false, 0
# +- *(4) HashAggregate(keys=[category#78], functions=[sum(cast(value#79 as double))])
# +- Exchange hashpartitioning(category#78, 200)
# +- *(3) HashAggregate(keys=[category#78], functions=[partial_sum(cast(value#79 as double))])
# +- *(3) Project [category#78, value#79]
# +- Scan ExistingRDD[category#78,value#79,flag#80]
final_frame_2
现在相同:
final_frame_2.explain()
#== Physical Plan ==
#*(2) HashAggregate(keys=[category#0], functions=[sum((cast(value#1 as double) * cast(flag#2 as double))), sum(cast(value#1 as double))])
#+- Exchange hashpartitioning(category#0, 200)
# +- *(1) HashAggregate(keys=[category#0], functions=[partial_sum((cast(value#1 as double) * cast(flag#2 as double))), partial_sum(cast(value#1 as double))])
# +- Scan ExistingRDD[category#0,value#1,flag#2]
注意:严格来说,这与您给出的示例输出的完全不完全相同(如下所示),因为您的内部联接将消除所有没有类别的类别flag = 1
行。
final_frame.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#| A|10.0|22.0|
#+--------+----+----+
您可以将聚合添加到总和flag
上,并过滤总和为零的那些(如果需要),而对性能的影响很小。
final_frame_3 = df.groupBy("category").agg(
F.sum(F.col("value")*F.col("flag")).alias("foo1"),
F.sum(F.col("value")).alias("foo2"),
F.sum(F.col("flag")).alias("foo3")
).where(F.col("foo3")!=0).drop("foo3")
final_frame_3.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#| A|10.0|22.0|
#+--------+----+----+
答案 1 :(得分:0)
请注意,联接操作很昂贵。 您可以执行此操作并将标志添加到您的组:
frame_1 = df.groupBy(["category", "flag"]).agg(F.sum('value').alias('foo1'))
如果您有两个以上的标志,并且想执行flag == 1 vs the rest
,则:
import pyspark.sql.functions as F
frame_1 = df.withColumn("flag2", F.when(F.col("flag") == 1, 1).otherwise(0))
frame_1 = df.groupBy(["category", "flag2"]).agg(F.sum('value').alias('foo1'))
如果您要对所有行进行groupby申请,只需创建一个新框架,然后在该框架中再次汇总类别:
frame_1 = df.groupBy("category").agg(F.sum('foo1').alias('foo2'))
不可能同时完成这两个步骤,因为实际上存在一个组重叠。