Complex UDF for groups of records: I think a UDF is needed to solve this problem

Time: 2019-05-21 17:59:19

Tags: apache-spark apache-spark-sql user-defined-functions

I have to find out when a particular store changed its brand, and populate that change month on every row. This should be done for every store.

+------+-----------+---------------+-------------+-------------+
|MTH_ID| store_id  |     brand     |    brndSales|   TotalSales|
+------+-----------+---------------+-------------+-------------+
|201801|      10941|            115|  80890.44900| 135799.66400|
|201712|      10941|            123| 517440.74500| 975893.79000|
|201711|      10941|            99 | 371501.92100| 574223.52300|
|201710|      10941|            115| 552435.57800| 746912.06700|
|201709|      10941|            115|1523492.60700|1871480.06800|
|201708|      10941|            115|1027698.93600|1236544.50900|
|201707|      10941|            33 |1469219.86900|1622949.53000|
+------+-----------+---------------+-------------+-------------+

The expected output looks like this:

+------+-----------+---------------+-------------+-------------+----------+
|MTH_ID| store_id  |     brand     |    brndSales|   TotalSales|switchdate|
+------+-----------+---------------+-------------+-------------+----------+
|201801|      10941|            115|  80890.44900| 135799.66400|    201712|
|201712|      10941|            123| 517440.74500| 975893.79000|    201711|
|201711|      10941|            99 | 371501.92100| 574223.52300|    201710|
|201710|      10941|            115| 552435.57800| 746912.06700|    201707|
|201709|      10941|            115|1523492.60700|1871480.06800|    201707|
|201708|      10941|            115|1027698.93600|1236544.50900|    201707|
|201707|      10941|            33 |1469219.86900|1622949.53000|    201706|
+------+-----------+---------------+-------------+-------------+----------+

I thought about applying lag, but we need to check whether the brand column actually changed; if the brand has not changed, we have to fill in the month of the last change (a sketch of this lag check follows the input data below).

Input data:

// toDF requires the implicits of an active SparkSession:
// import spark.implicits._
val data = Seq(
  (201801, 10941, 115,   80890.44900,  135799.66400),
  (201712, 10941, 123,  517440.74500,  975893.79000),
  (201711, 10941,  99,  371501.92100,  574223.52300),
  (201710, 10941, 115,  552435.57800,  746912.06700),
  (201709, 10941, 115, 1523492.60700, 1871480.06800),
  (201708, 10941, 115, 1027698.93600, 1236544.50900),
  (201707, 10941,  33, 1469219.86900, 1622949.53000)
).toDF("MTH_ID", "store_id", "brand", "brndSales", "TotalSales")
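As a minimal sketch of the lag idea mentioned above (assuming a SparkSession with its implicits in scope and the data DataFrame just defined; the column name brand_changed is only an illustration), the change flag could be computed like this:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{lag, when}

// Compare each month's brand with the previous month's brand per store.
// `brand_changed` is a hypothetical name used for illustration only.
val w = Window.partitionBy($"store_id").orderBy($"MTH_ID")
val flagged = data.withColumn("brand_changed",
  when(lag($"brand", 1).over(w) === $"brand", 0).otherwise(1))
flagged.show()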

Output after applying the answer below:

+------+--------+-----+-----------+-----------+---------------+---+----------+
|MTH_ID|store_id|brand|  brndSales| TotalSales|prev_brand_flag|grp|switchdate|
+------+--------+-----+-----------+-----------+---------------+---+----------+
|201801|   10941|  115|  80890.449| 135799.664|              1|  5|    201801|
|201712|   10941|  123| 517440.745|  975893.79|              1|  4|    201712|
|201711|   10941|   99| 371501.921| 574223.523|              1|  3|    201711|
|201710|   10941|  115| 552435.578| 746912.067|              0|  2|    201708|
|201709|   10941|  115|1523492.607|1871480.068|              0|  2|    201708|
|201708|   10941|  115|1027698.936|1236544.509|              1|  2|    201708|
|201707|   10941|   33|1469219.869| 1622949.53|              1|  1|    201707|
+------+--------+-----+-----------+-----------+---------------+---+----------+

Are the available built-in functions enough for this, or is a UDF needed?

1 answer:

Answer 0 (score: 1)

A PySpark solution.

Use lag together with a running sum to check whether the value changed from the previous row; if it did, increment a counter that defines the group. Once the grouping is done, take the min date per group.

from pyspark.sql import Window
from pyspark.sql.functions import lag, min, sum, when

# Flag rows whose brand differs from the previous month's brand
# (the first row per store has no previous row and gets 1).
w1 = Window.partitionBy(df.store_id).orderBy(df.MTH_ID)
df = df.withColumn('prev_brand_flag', when(lag(df.brand).over(w1) == df.brand, 0).otherwise(1))
# A running sum of the flag assigns a group id to each consecutive run of the same brand.
df = df.withColumn('grp', sum(df.prev_brand_flag).over(w1))
# The switch date is the earliest month within each (store, group).
w2 = Window.partitionBy(df.store_id, df.grp)
res = df.withColumn('switchdate', min(df.MTH_ID).over(w2))
res.show()
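
Since the input in the question is built in Scala, here is a sketch of the same window logic translated to Scala (assuming spark.implicits._ is imported and data is the DataFrame defined above):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{lag, min, sum, when}

// Same logic as the PySpark answer: flag brand changes, build group ids
// with a running sum, then take the earliest month per (store, group).
val w1 = Window.partitionBy($"store_id").orderBy($"MTH_ID")
val grouped = data
  .withColumn("prev_brand_flag",
    when(lag($"brand", 1).over(w1) === $"brand", 0).otherwise(1))
  .withColumn("grp", sum($"prev_brand_flag").over(w1))
val w2 = Window.partitionBy($"store_id", $"grp")
val res = grouped.withColumn("switchdate", min($"MTH_ID").over(w2))
res.show()

This also answers the UDF question: built-in window functions are enough here, and they are generally preferable to a UDF because Catalyst can optimize them.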

Look at the results of the intermediate DataFrames and you will see how the logic works.