Pyspark多个简单聚合的最佳做法-countif / sumif格式

时间:2019-05-13 12:52:17

标签: pyspark

我对Pyspark还是比较陌生,我正在寻找有关在长数据帧上进行多个简单聚合的最佳方法的建议。

我有一个交易数据框架,其中客户每天有多个交易,我想按客户分组并保留一些变量,例如总和,以及一些变量,例如条件成立的日期。

所以我想为每位客户知道:

  • 他们从A类购买了多少天
  • 他们周末购物了多少天
  • 所有交易的总支出
  • 理想情况下还要加上其他一些东西,例如上个月的交易,最大支出,周末最大支出等。

因此,在Excel术语中基本上有很多“ countifs”或“ sumifs”。

我觉得最好像下面那样分别计算所有这些,然后将它们合并在一起(按照对pyspark sql query : count distinct values with conditions的回答) ),因为我有很多客户,所以加入会很昂贵,而且由于有些客户在任何周末都不会进行交易,因此我认为这需要成为一个加入的伙伴,而不仅仅是一个麻烦:

total_variables = transactions.groupby('cust_id').agg(sum("spend").alias("total_spend")) 
weekend_variables = transactions.where(transactions.weekend_flag == "Y").groupby('cust_id').agg(countDistinct("date").alias("days_txn_on_weekend"))  
catA_variables = transactions.where(transactions.category == "CatA").groupby('cust_id').agg(countDistinct("date").alias("days_txn_cat_a")) 
final_df = total_variables.join(weekend_variables, col('total_variables.id') == col('weekend_variables.id'), 'left') \
                          .join(catA_variables, col('df1.id') == col('catA_variables.id'), 'left')

一种方法是制作部分为空的列,然后对它们调用计数distint或求和,如下所示:

transactions_additional = transactions.withColumn('date_if_weekend',
                                                psf.when(psf.col("weekend_flag") == "Y",
                                                psf.col('date')).otherwise(psf.lit(None)))
                                      .withColumn('date_if_CatA',
                                                psf.when(psf.col("category") == "CatA",
                                                psf.col('date')).otherwise(psf.lit(None)))
final_df = total_variables .groupby('cust_id').agg(psf.countDistinct("date_if_weekend").alias("days_txn_on_weekend"),
                                                   psf.countDistinct("date_if_CatA").alias("days_txn_cat_a"),
                                                   psf.sum("spend").alias("total_spend"))

但这在生成列方面似乎很浪费,并且可能与我最终想要计算的内容失之交臂。

我认为我可以在pyspark-sql中使用countdistinct和case来做到这一点,但我希望有一种使用pyspark语法的更好方法-也许使用以下格式的自定义聚合UDF:

aggregated_df = transactions.groupby('cust_id').agg(<something that returns total spend>,
                                                    <something that returns days purchased cat A>,
                                                    <something that returns days purchased on the weekend>,)

这可能吗?

1 个答案:

答案 0 :(得分:1)

spark pandas_udf函数对于汇总结果非常有用且可读。 这是示例代码,为获得所需的输出,您可以扩展以添加任何其他汇总结果。

import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType,IntegerType,LongType,StructType,StructField,StringType
import pandas as pd

#you can add last month maximum spend, maximum spend on the weekend etc and 
#update agg_data function
agg_schema = StructType(
    [StructField("cust_id", StringType(), True),
     StructField("days_txn_on_weekend", IntegerType(), True),
     StructField("days_txn_cat_a", IntegerType(), True),
     StructField("total_spend", IntegerType(), True)
     ]
)

@F.pandas_udf(agg_schema, F.PandasUDFType.GROUPED_MAP)
def agg_data(pdf):
    days_txn_on_weekend =  pdf.query("weekend_flag == 'Y'")['date'].nunique()
    days_txn_cat_a = pdf.query("category == 'CatA'")['date'].nunique()
    total_spend = pdf['spend'].sum()
    return pd.DataFrame([(pdf.cust_id[0],days_txn_on_weekend,days_txn_cat_a,total_spend)])

transactions = spark.createDataFrame(
    [
    ('cust_1', 'CatA', 20190101, 'N', 10),
    ('cust_1', 'CatA', 20190101, 'N', 20),
    ('cust_1', 'CatA', 20190105, 'Y',40),
    ('cust_1', 'CatA', 20190105, 'Y',10),
    ('cust_1', 'CatA', 20190112, 'Y', 20),
    ('cust_1', 'CatA', 20190113, 'Y', 10),
    ('cust_1', 'CatA', 20190101, 'N',20),
    ('cust_1', 'CatB', 20190105, 'Y', 50),
    ('cust_1', 'CatB', 20190105, 'Y', 50),
    ('cust_2', 'CatA', 20190115, 'N', 10),
    ('cust_2', 'CatA', 20190116, 'N', 20),
    ('cust_2', 'CatA', 20190117, 'N', 40),
    ('cust_2', 'CatA', 20190119, 'Y', 10),
    ('cust_2', 'CatA', 20190119, 'Y', 20),
    ('cust_2', 'CatA', 20190120, 'Y', 10),
    ('cust_3', 'CatB', 20190108, 'N', 10),
    ],
    ['cust_id','category','date','weekend_flag','spend']
)
transactions.groupBy('cust_id').apply(agg_data).show()

结果

+-------+-------------------+--------------+-----------+
|cust_id|days_txn_on_weekend|days_txn_cat_a|total_spend|
+-------+-------------------+--------------+-----------+
| cust_2|                  2|             5|        110|
| cust_3|                  0|             0|         10|
| cust_1|                  3|             4|        230|
+-------+-------------------+--------------+-----------+