从列中获取最大值并与列中的每个项目进行比较

时间:2019-09-26 15:33:18

标签: pyspark

我有一个数据框,例如:

id | value | date1       | date2
-------------------------------------
1  | 20    | 2015-09-01  | 2018-03-01 
1  | 30    | 2019-04-04  | 2015-03-02 
1  | 40    | 2014-01-01  | 2016-06-09 
2  | 15    | 2014-01-01  | 2013-06-01 
2  | 25    | 2019-07-18  | 2016-07-07 

,并希望为每个id返回sum(value),其中date1<max(date2)代表该id。在上面的示例中,我们将获得:

id | sum_value 
-----------
1  | 60     
2  | 15 

因为对于id 1,max(date2)2018-03-01,并且第一行和第三行符合条件date1<max(date2),因此该值是20的和。 40

我已经尝试过下面的代码,但不能在max函数之外使用agg

df.withColumn('sum_value',F.when(F.col('date1')<F.max(F.col('date2')), value).otherwise(0))
            .groupby(['id']) 

您有什么建议吗?该表有20亿行,因此我正在寻找除重新加入之外的其他选择。

1 个答案:

答案 0 :(得分:2)

您可以使用Window函数。您的要求的直接翻译将是:

from pyspark.sql.functions import col, max as _max, sum as _sum
from pyspark.sql import Window

df.withColumn("max_date2", _max("date2").over(Window.partitionBy("id")))\
    .where(col("date1") < col("max_date2"))\
    .groupBy("id")\
    .agg(_sum("value").alias("sum_value"))\
    .show()
#+---+---------+
#| id|sum_value|
#+---+---------+
#|  1|     60.0|
#|  2|     15.0|
#+---+---------+