我有一个数据框,例如:
id | value | date1 | date2
-------------------------------------
1 | 20 | 2015-09-01 | 2018-03-01
1 | 30 | 2019-04-04 | 2015-03-02
1 | 40 | 2014-01-01 | 2016-06-09
2 | 15 | 2014-01-01 | 2013-06-01
2 | 25 | 2019-07-18 | 2016-07-07
,并希望为每个id
返回sum(value)
,其中date1<max(date2)
代表该id
。在上面的示例中,我们将获得:
id | sum_value
-----------
1 | 60
2 | 15
因为对于id
1,max(date2)
是2018-03-01
,并且第一行和第三行符合条件date1<max(date2)
,因此该值是20
的和。 40
。
我已经尝试过下面的代码,但不能在max
函数之外使用agg
。
df.withColumn('sum_value',F.when(F.col('date1')<F.max(F.col('date2')), value).otherwise(0))
.groupby(['id'])
您有什么建议吗?该表有20亿行,因此我正在寻找除重新加入之外的其他选择。
答案 0 :(得分:2)
您可以使用Window
函数。您的要求的直接翻译将是:
from pyspark.sql.functions import col, max as _max, sum as _sum
from pyspark.sql import Window
df.withColumn("max_date2", _max("date2").over(Window.partitionBy("id")))\
.where(col("date1") < col("max_date2"))\
.groupBy("id")\
.agg(_sum("value").alias("sum_value"))\
.show()
#+---+---------+
#| id|sum_value|
#+---+---------+
#| 1| 60.0|
#| 2| 15.0|
#+---+---------+