PySpark:基于观察序列的组类型编号组

时间:2019-04-19 21:33:58

标签: group-by pyspark

我正在尝试根据其到达顺序来确定批次类型的顺序。

我从这个DataFrame开始

+--------+-----+
|sequence|batch|
+————————+—————+
|       1|    a|
|       2|    a|
|       3|    a|
|       4|    a|
|       5|    b|
|       6|    b|
|       7|    b|
|       8|    a|
|       9|    a|
|      10|    a|
|      11|    c|
|      12|    c|
|      13|    c|
|      14|    c|
+———————-+---——+

我想要做的是按到达的顺序识别批次,如下所示。

+--------+-----++----------+
|sequence|batch|batch_order|
+————————+—————+————------—+
|       1|    a|          1|
|       2|    a|          1|
|       3|    a|          1|
|       4|    a|          1|
|       5|    b|          2|
|       6|    b|          2|
|       7|    b|          2|
|       8|    a|          3|
|       9|    a|          3|
|      10|    a|          3|
|      11|    c|          4|
|      12|    c|          4|
|      13|    c|          4|
|      14|    c|          4|
+———————-+---——+————-------+

当我按批次分组时,所有A型批次都被分组在一起。我需要按顺序到达的子批次。

这是创建测试数据的示例代码。

from pyspark.sql import SparkSession
from pyspark.sql.types import  IntegerType
import pyspark.sql.functions as F
from pyspark.sql import Window


spark = SparkSession.builder.appName('test').master("local[*]").getOrCreate()

df = spark.createDataFrame([[1, 'a'],
 [2, 'a'],
 [3, 'a'],
 [4, 'a'],
 [5, 'b'],
 [6, 'b'],
 [7, 'b'],
 [8, 'a'],
 [9, 'a'],
 [10, 'a'],
 [11, 'c'],
 [12, 'c'],
 [13, 'c'],
 [14, 'c']], schema=['order', 'batch'])
df = df.withColumn('order', F.col("order").cast(IntegerType()))


我尝试了此窗口,但它按批处理类型而不是按批处理顺序分组。

df1 = df.withColumn("row_num", F.row_number().over(Window.partitionBy("batch").orderBy("order")))

df1.show()

+-----+-----+-------+
|order|batch|row_num|
+-----+-----+-------+
|   11|    c|      1|
|   12|    c|      2|
|   13|    c|      3|
|   14|    c|      4|
|    5|    b|      1|
|    6|    b|      2|
|    7|    b|      3|
|    1|    a|      1|
|    2|    a|      2|
|    3|    a|      3|
|    4|    a|      4|
|    8|    a|      5|
|    9|    a|      6|
|   10|    a|      7|
+-----+-----+-------+

1 个答案:

答案 0 :(得分:0)

一种方法是使用lag()窗口函数检索先前的batch值,然后将其与当前batch进行比较,使用此标志进行累加和。

from pyspark.sql import functions as F, Window

# set up the Window Spec
# note: partitionBy(F.lit(0)) just to bypass the WARN message
win = Window.partitionBy(F.lit(0)).orderBy('sequence')

# get the prev_batch, 
# set up the flag based on: batch == prev_batch ? 0 : 1
# batch_order is the running sum with the column-flag
df.withColumn('prev_batch', F.lag('batch').over(win)) \
  .withColumn('flag', F.when(F.col('batch') == F.col('prev_batch'),0).otherwise(1)) \
  .withColumn('batch_order', F.sum('flag').over(win)) \
  .drop('prev_batch', 'flag') \
  .sort('sequence') \
  .show()
#+--------+-----+-----------+
#|sequence|batch|batch_order|
#+--------+-----+-----------+
#|       1|    a|          1|
#|       2|    a|          1|
#|       3|    a|          1|
#|       4|    a|          1|
#|       5|    b|          2|
#|       6|    b|          2|
#|       7|    b|          2|
#|       8|    a|          3|
#|       9|    a|          3|
#|      10|    a|          3|
#|      11|    c|          4|
#|      12|    c|          4|
#|      13|    c|          4|
#|      14|    c|          4|
#+--------+-----+-----------+