PySpark groupBy and taking specific columns

Posted: 2019-12-03 12:46:50

Tags: pyspark apache-spark-sql pyspark-sql

Suppose I have a dataframe like the following:

ProductId   StoreId Prediction      Index
24524       20      3               19
24524       20      5               20
24524       20      1               21
24524       20      2               22
24524       20      3               23
24524       20      1               24
24524       20      3               25
24524       20      4               26
24524       20      5               27
24524       20      6               28
24524       20      1               29
37654       23      8               9
37654       23      3               10
37654       23      4               11
37654       23      5               12
37654       23      6               13
37654       23      7               14
37654       23      8               15
37654       23      4               16
37654       23      2               17
37654       23      4               18
37654       23      3               19
37654       23      7               20
37654       23      7               21
37654       23      3               22
37654       23      2               23
37654       23      3               24

I want to take the average of the Prediction column over the last 7 indexes for each product and store.

ProductId   StoreId Prediction(Average)
24524       20      3.28                # (this average includes Index 23, 24, 25, 26, 27, 28 and 29)
37654       23      4.14                # (this average includes Index 18, 19, 20, 21, 22, 23 and 24)
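
(As a check against the sample data: for product 24524 the last 7 predictions are 3, 1, 3, 4, 5, 6 and 1, so the average is 23 / 7 ≈ 3.286; for product 37654 they are 4, 3, 7, 7, 3, 2 and 3, giving 29 / 7 ≈ 4.143.)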

How should I use groupBy for this?

  

df.groupBy(["ProductId", "StoreId"]).agg({'Prediction': 'avg'})

Can you help me?

2 Answers:

Answer 0 (score: 1)

You can use a window frame with ROWS BETWEEN 6 PRECEDING AND CURRENT ROW. First register the dataframe as a temporary table:

>>> df.registerTempTable("temp")

Then distinct is used together with last_value to keep only the final running average for each ProductId/StoreId:

>>> spark.sql("select distinct ProductId,StoreId,last_value(avg_spent_time) over(partition by ProductId,StoreId order by ProductId,StoreId) as result from (select ProductId,StoreId,avg(Prediction) over(order by ProductId ROWS BETWEEN 6 PRECEDING AND CURRENT ROW) as avg_spent_time from temp) t").show()
+---------+-------+------------------+
|ProductId|StoreId|            result|
+---------+-------+------------------+
|    24524|     20|3.2857142857142856|
|    37654|     23| 4.142857142857143|
+---------+-------+------------------+
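
The same sliding-frame idea can also be written with the DataFrame API instead of SQL. The sketch below is only an illustration under a few assumptions: it assumes the question's dataframe is bound to the name df, and it additionally partitions the rolling average by ProductId and StoreId so that rows from different products never fall into the same frame.

from pyspark.sql import Window
import pyspark.sql.functions as F

# rolling average over the current row and the 6 rows before it, per product/store
w_avg = Window.partitionBy("ProductId", "StoreId").orderBy("Index").rowsBetween(-6, Window.currentRow)
# number rows so the row with the highest Index per product/store can be kept
w_last = Window.partitionBy("ProductId", "StoreId").orderBy(F.col("Index").desc())

(df.withColumn("avg7", F.avg("Prediction").over(w_avg))
   .withColumn("rn", F.row_number().over(w_last))
   .filter(F.col("rn") == 1)
   .select("ProductId", "StoreId", F.col("avg7").alias("Prediction(Average)"))
   .show())

For the sample data this should give the same two averages as above (3.2857… and 4.1428…).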

Answer 1 (score: 1)

This can be done with a window function:

from pyspark.sql.window import Window
import pyspark.sql.functions as f

# create a window partitioned by product/store, ordered by Index descending
col_list = ['ProductId', 'StoreId']
window = Window.partitionBy([f.col(x) for x in col_list]).orderBy(df['Index'].desc())

# keep only the last 7 rows of each partition
df = df.select('*', f.rank().over(window).alias('rank')).filter(f.col('rank') <= 7).drop('rank')
# calculate the average prediction per product/store
df.groupBy(["ProductId", "StoreId"]).agg(f.avg(f.col("Prediction"))).show()

+---------+-------+------------------+
|ProductId|StoreId|   avg(Prediction)|
+---------+-------+------------------+
|    37654|     23| 4.142857142857143|
|    24524|     20|3.2857142857142856|
+---------+-------+------------------+
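
To try either answer locally, here is a minimal sketch that rebuilds the sample dataframe from the question (it assumes a plain local SparkSession and integer columns; the asker's real schema may differ):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# sample rows copied from the question: (ProductId, StoreId, Prediction, Index)
rows = [
    (24524, 20, 3, 19), (24524, 20, 5, 20), (24524, 20, 1, 21), (24524, 20, 2, 22),
    (24524, 20, 3, 23), (24524, 20, 1, 24), (24524, 20, 3, 25), (24524, 20, 4, 26),
    (24524, 20, 5, 27), (24524, 20, 6, 28), (24524, 20, 1, 29),
    (37654, 23, 8, 9),  (37654, 23, 3, 10), (37654, 23, 4, 11), (37654, 23, 5, 12),
    (37654, 23, 6, 13), (37654, 23, 7, 14), (37654, 23, 8, 15), (37654, 23, 4, 16),
    (37654, 23, 2, 17), (37654, 23, 4, 18), (37654, 23, 3, 19), (37654, 23, 7, 20),
    (37654, 23, 7, 21), (37654, 23, 3, 22), (37654, 23, 2, 23), (37654, 23, 3, 24),
]
df = spark.createDataFrame(rows, ["ProductId", "StoreId", "Prediction", "Index"])

Note that Answer 1 uses rank(); since Index is unique within each product/store in this sample, rank() and row_number() behave the same here, but row_number() is the safer choice if ties in the ordering column are possible.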