Complex data grouping in PySpark

Date: 2019-03-31 05:21:43

Tags: apache-spark pyspark

I am trying to produce an aggregation with fairly complex requirements, and I want to write generic code (not tied to any field values in the DF).

With my current code I get the result I want, but I have to supply hardcoded values to get there. I would like to write more generic, "functional" code that achieves the same thing without any hardcoded values.

Input data -

ID  Day category    Amount
A11 2   X           914.89
A11 2   X           106.01
A11 2   Y           481.88
A11 2   X           885.56
A11 14  X           733.1
A11 17  Q           694.53
A11 19  Z           15.86
A11 20  Y           99.22
A11 20  S           404.96
A11 24  P           8.28
A11 25  Q           718.22
A11 25  S           314.13
A11 27  Y           599.16
A11 28  P           817.1

Scenario - For each day, compute statistics per ID over the transactions of the previous five days, not including the transactions of the current day. For example, on day 6, only transactions from day 1 to day 5 should be considered (a rolling 5-day window). The statistics we need to calculate are:

• The maximum transaction value per account over the previous 5 days

• The average transaction value per account over the previous 5 days

• The total transaction value of "X", "Z" and "R" transaction types per account over the previous 5 days

To achieve this, I wrote the code below -

    from pyspark.sql import Window
    from pyspark.sql.functions import avg, col, max, sum, when

    # Previous 5 days per ID, excluding the current day
    tranwindow=Window.partitionBy("ID").orderBy("Day").rangeBetween(-5,-1)
    outDF=df\
    .withColumn("Maximum",max(col("Amount")).over(tranwindow))\
    .withColumn("Average",avg(col("Amount")).over(tranwindow))\
    .withColumn("X_TOTAL_VALUE",sum(when(col("category") == "X", col("Amount"))).over(tranwindow))\
    .withColumn("Z_TOTAL_VALUE",sum(when(col("category") == "Z", col("Amount"))).over(tranwindow))\
    .withColumn("R_TOTAL_VALUE",sum(when(col("category") == "R", col("Amount"))).over(tranwindow))\
    .select("ID","Day","Maximum","Average","X_TOTAL_VALUE","Z_TOTAL_VALUE","R_TOTAL_VALUE").orderBy("ID","Day")

This code gets me the results I want, but it is tightly coupled to the category values (hardcoded in the code).

|accountId|transactionDay|Maximum|           Average|     X_TOTAL_VALUE|     Z_TOTAL_VALUE|     R_TOTAL_VALUE|
|      A11|             2|   null|              null|              null|              null|              null|
|      A11|             2|   null|              null|              null|              null|              null|
|      A11|             2|   null|              null|              null|              null|              null|
|      A11|             2|   null|              null|              null|              null|              null|
|      A11|            14|   null|              null|              null|              null|              null|
|      A11|            17|  733.1|             733.1|             733.1|              null|              null|
|      A11|            19|  733.1|           713.815|             733.1|              null|              null|
|      A11|            20| 694.53|           355.195|              null|             15.86|              null|
|      A11|            20| 694.53|           355.195|              null|             15.86|              null|
|      A11|            24| 404.96|173.34666666666666|              null|             15.86|              null|
|      A11|            25| 404.96|170.81999999999996|              null|              null|              null|
|      A11|            25| 404.96|170.81999999999996|              null|              null|              null|
|      A11|            27| 718.22| 346.8766666666667|              null|              null|              null|
|      A11|            28| 718.22|          409.9475|              null|              null|              null|

How can we write this in a more generic way? Is rollup/cube an option here?
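For illustration, the generic shape I have in mind would build the per-category columns from whatever category values are present in the data, roughly like this (just a sketch of the idea - the helper names categories and cat_cols are only illustrative, and I have not validated it):

    from pyspark.sql import Window
    from pyspark.sql.functions import avg, col, max, sum, when

    tranwindow = Window.partitionBy("ID").orderBy("Day").rangeBetween(-5, -1)

    # Collect the distinct category values from the data instead of hardcoding them
    categories = [r["category"] for r in df.select("category").distinct().collect()]

    # One conditional windowed sum per category, built dynamically
    cat_cols = [
        sum(when(col("category") == c, col("Amount"))).over(tranwindow).alias(c + "_TOTAL_VALUE")
        for c in categories
    ]

    outDF = df.select(
        "ID", "Day",
        max(col("Amount")).over(tranwindow).alias("Maximum"),
        avg(col("Amount")).over(tranwindow).alias("Average"),
        *cat_cols
    ).orderBy("ID", "Day")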

1 answer:

Answer 0 (score: 0)

I'm not sure I follow the logic here, since your results seem different from what I would expect. In any case, here is an example of what I tried.

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

# Explicit schema so that Day and Amount come through with numeric types
schema = StructType([
    StructField('ID', StringType(), True),
    StructField('Day', IntegerType(), True),
    StructField('category', StringType(), True),
    StructField('Amount', DoubleType(), True),
])

data = [('All',2,"X",914.89),
('All',2,"X",106.01),
('All',2,"Y",481.88),
('All',2,"X",885.56),
('All',14,"X",733.1),
('All',17,"Q",694.53),
('All',19,"Z",15.86),
('All',20,"Y",99.22),
('All',20,"S",404.96),
('All',24,"P",8.28),
('All',25,"Q",718.22),
('All',25,"S",314.13),
('All',27,"Y",599.16),
('All',28,"P",817.1)]
from pyspark.sql import functions as f

df = spark.createDataFrame(sc.parallelize(data), schema).sort('Day', 'category')

df.show()

+---+---+--------+------+
| ID|Day|category|Amount|
+---+---+--------+------+
|All|  2|       X|914.89|
|All|  2|       X|885.56|
|All|  2|       X|106.01|
|All|  2|       Y|481.88|
|All| 14|       X| 733.1|
|All| 17|       Q|694.53|
|All| 19|       Z| 15.86|
|All| 20|       S|404.96|
|All| 20|       Y| 99.22|
|All| 24|       P|  8.28|
|All| 25|       Q|718.22|
|All| 25|       S|314.13|
|All| 27|       Y|599.16|
|All| 28|       P| 817.1|
+---+---+--------+------+

from pyspark.sql import Window

# Previous 5 days per account, excluding the current day
w1 = Window.partitionBy('ID').orderBy('Day').rangeBetween(-5, -1)
maxCol = f.max(df['Amount']).over(w1)
avgCol = f.avg(df['Amount']).over(w1)

outDF1 = df.select(df['ID'], df['Day'], df['category'], df['Amount'],
                   maxCol.alias('Max_Amount'), avgCol.alias('Avg_Amount'))
outDF1.show()

+---+---+--------+------+----------+------------------+
| ID|Day|category|Amount|Max_Amount|        Avg_Amount|
+---+---+--------+------+----------+------------------+
|All|  2|       X|914.89|      null|              null|
|All|  2|       X|106.01|      null|              null|
|All|  2|       X|885.56|      null|              null|
|All|  2|       Y|481.88|      null|              null|
|All| 14|       X| 733.1|      null|              null|
|All| 17|       Q|694.53|     733.1|             733.1|
|All| 19|       Z| 15.86|     733.1|           713.815|
|All| 20|       S|404.96|    694.53|           355.195|
|All| 20|       Y| 99.22|    694.53|           355.195|
|All| 24|       P|  8.28|    404.96|173.34666666666666|
|All| 25|       S|314.13|    404.96|170.81999999999996|
|All| 25|       Q|718.22|    404.96|170.81999999999996|
|All| 27|       Y|599.16|    718.22| 346.8766666666667|
|All| 28|       P| 817.1|    718.22|          409.9475|
+---+---+--------+------+----------+------------------+

w2 = Window.partitionBy('category').orderBy('Day').rowsBetween(-6, -1)
sumCol = f.sum(df['Amount']).over(w2)

outDF2 = df.select(df['ID'], df['category'], df['Day'], df['Amount'], sumCol.alias('Sum_Amount'))
outDF2.sort('category', 'Day').show()
# Running sum over the preceding rows within each category (row-based window)
+---+--------+---+------+----------+
| ID|category|Day|Amount|Sum_Amount|
+---+--------+---+------+----------+
|All|       P| 24|  8.28|      null|
|All|       P| 28| 817.1|      8.28|
|All|       Q| 17|694.53|      null|
|All|       Q| 25|718.22|    694.53|
|All|       S| 20|404.96|      null|
|All|       S| 25|314.13|    404.96|
|All|       X|  2|914.89|      null|
|All|       X|  2|885.56|    1020.9|
|All|       X|  2|106.01|    914.89|
|All|       X| 14| 733.1|   1906.46|
|All|       Y|  2|481.88|      null|
|All|       Y| 20| 99.22|    481.88|
|All|       Y| 27|599.16|     581.1|
|All|       Z| 19| 15.86|      null|
+---+--------+---+------+----------+
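Note that w2 sums over the preceding six rows within each category, not over the previous five days per account. A day-based variant closer to the original requirement could look roughly like this (a sketch only; w3 and outDF3 are illustrative names and this was not run against the outputs shown here):

# Sketch: previous 5 days (excluding the current day), per account and category
w3 = Window.partitionBy('ID', 'category').orderBy('Day').rangeBetween(-5, -1)
outDF3 = df.select(df['ID'], df['category'], df['Day'], df['Amount'],
                   f.sum(df['Amount']).over(w3).alias('Sum_Amount'))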

jdf = outDF2.groupBy('ID', 'category', 'Day').pivot('category', ['X', 'Y', 'Z']).agg(f.first(outDF2['Sum_Amount']))
jdf.show()

+---+--------+---+-------+------+----+
| ID|category|Day|      X|     Y|   Z|
+---+--------+---+-------+------+----+
|All|       Q| 17|   null|  null|null|
|All|       Q| 25|   null|  null|null|
|All|       Y|  2|   null|  null|null|
|All|       Y| 20|   null|481.88|null|
|All|       Y| 27|   null| 581.1|null|
|All|       Z| 19|   null|  null|null|
|All|       X|  2|   null|  null|null|
|All|       X| 14|1906.46|  null|null|
|All|       S| 20|   null|  null|null|
|All|       S| 25|   null|  null|null|
|All|       P| 24|   null|  null|null|
|All|       P| 28|   null|  null|null|
+---+--------+---+-------+------+----+

Now you can join this back onto outDF1.
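One way that rejoin could look (a sketch; "final" is just an illustrative name, and the join keys follow from the columns shown above):

# Sketch: bring the pivoted per-category sums back alongside the per-row stats
final = outDF1.join(jdf, on=['ID', 'category', 'Day'], how='left')
final.orderBy('Day', 'category').show()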