Grouping consecutive rows in a PySpark DataFrame

Date: 2018-07-12 15:39:51

Tags: python pyspark

I have the following sample Spark DataFrame:

rdd = sc.parallelize([
    (1, "19:00:00", "19:30:00", 30),
    (1, "19:30:00", "19:40:00", 10),
    (1, "19:40:00", "19:43:00", 3),
    (2, "20:00:00", "20:10:00", 10),
    (1, "20:05:00", "20:15:00", 10),
    (1, "20:15:00", "20:35:00", 20)
])
df = spark.createDataFrame(rdd, ["user_id", "start_time", "end_time", "duration"])
df.show()

+-------+----------+--------+--------+
|user_id|start_time|end_time|duration|
+-------+----------+--------+--------+
|      1|  19:00:00|19:30:00|      30|
|      1|  19:30:00|19:40:00|      10|
|      1|  19:40:00|19:43:00|       3|
|      2|  20:00:00|20:10:00|      10|
|      1|  20:05:00|20:15:00|      10|
|      1|  20:15:00|20:35:00|      20|
+-------+----------+--------+--------+

I want to group consecutive rows based on the start and end times. For example, for the same user_id, if a row's start time equals the previous row's end time, I want to group them together and sum their durations.

The desired result is:

+-------+----------+--------+--------+
|user_id|start_time|end_time|duration|
+-------+----------+--------+--------+
|      1|  19:00:00|19:43:00|      43|
|      2|  20:00:00|20:10:00|      10|
|      1|  20:05:00|20:35:00|      30|
+-------+----------+--------+--------+

The first three rows of the DataFrame are grouped together because they all belong to user_id 1 and their start and end times form a continuous timeline.

Here is my initial approach:

Use the lag function (with a negative offset) to get the next start time:

from pyspark.sql.functions import *
from pyspark.sql import Window

# compute the next start time: lag with a negative offset returns the
# following row's value (equivalent to lead)
window = Window.partitionBy('user_id').orderBy('start_time')
df = df.withColumn("next_start_time", lag(df.start_time, -1).over(window))

df.show()

+-------+----------+--------+--------+---------------+
|user_id|start_time|end_time|duration|next_start_time|
+-------+----------+--------+--------+---------------+
|      1|  19:00:00|19:30:00|      30|       19:30:00|
|      1|  19:30:00|19:40:00|      10|       19:40:00|
|      1|  19:40:00|19:43:00|       3|       20:05:00|
|      1|  20:05:00|20:15:00|      10|       20:15:00|
|      1|  20:15:00|20:35:00|      20|           null|
|      2|  20:00:00|20:10:00|      10|           null|
+-------+----------+--------+--------+---------------+

Get the difference between the current row's end time and the next row's start time:

time_fmt = "HH:mm:ss"
timeDiff = unix_timestamp('next_start_time', format=time_fmt) - unix_timestamp('end_time', format=time_fmt) 

df = df.withColumn("difference", timeDiff)
df.show()

+-------+----------+--------+--------+---------------+----------+
|user_id|start_time|end_time|duration|next_start_time|difference|
+-------+----------+--------+--------+---------------+----------+
|      1|  19:00:00|19:30:00|      30|       19:30:00|         0|
|      1|  19:30:00|19:40:00|      10|       19:40:00|         0|
|      1|  19:40:00|19:43:00|       3|       20:05:00|      1320|
|      1|  20:05:00|20:15:00|      10|       20:15:00|         0|
|      1|  20:15:00|20:35:00|      20|           null|      null|
|      2|  20:00:00|20:10:00|      10|           null|      null|
+-------+----------+--------+--------+---------------+----------+

Now my idea is to use the sum function over a window to get a cumulative sum of the duration and then do a groupBy, but my approach has many flaws.
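
In code, that idea would look roughly like this (just a sketch of the cumulative-sum step with a hypothetical cum_duration column; on its own a plain running sum never restarts at a gap, so it is not a full solution):

import pyspark.sql.functions as f
from pyspark.sql import Window

# sketch only: running total of duration per user, ordered by start_time
w = Window.partitionBy("user_id").orderBy("start_time")\
          .rowsBetween(Window.unboundedPreceding, 0)
df = df.withColumn("cum_duration", f.sum("duration").over(w))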

2 Answers:

Answer 0 (score: 4)

Here is one approach:

Gather the rows into groups, where a group is a set of rows with the same user_id that are consecutive (each start_time matches the previous end_time). You can then use this group to do your aggregation.

One way to get there is to create intermediate indicator columns that tell you whether the user has changed or the times are not contiguous. Then take a cumulative sum over the indicator column to create the group.

For example:

import pyspark.sql.functions as f
from pyspark.sql import Window

w1 = Window.orderBy("start_time")
df = df.withColumn(
        "userChange",
        (f.col("user_id") != f.lag("user_id").over(w1)).cast("int")
    )\
    .withColumn(
        "timeChange",
        (f.col("start_time") != f.lag("end_time").over(w1)).cast("int")
    )\
    .fillna(
        0,
        subset=["userChange", "timeChange"]
    )\
    .withColumn(
        "indicator",
        (~((f.col("userChange") == 0) & (f.col("timeChange")==0))).cast("int")
    )\
    .withColumn(
        "group",
        f.sum(f.col("indicator")).over(w1.rangeBetween(Window.unboundedPreceding, 0))
    )
df.show()
#+-------+----------+--------+--------+----------+----------+---------+-----+
#|user_id|start_time|end_time|duration|userChange|timeChange|indicator|group|
#+-------+----------+--------+--------+----------+----------+---------+-----+
#|      1|  19:00:00|19:30:00|      30|         0|         0|        0|    0|
#|      1|  19:30:00|19:40:00|      10|         0|         0|        0|    0|
#|      1|  19:40:00|19:43:00|       3|         0|         0|        0|    0|
#|      2|  20:00:00|20:10:00|      10|         1|         1|        1|    1|
#|      1|  20:05:00|20:15:00|      10|         1|         1|        1|    2|
#|      1|  20:15:00|20:35:00|      20|         0|         0|        0|    2|
#+-------+----------+--------+--------+----------+----------+---------+-----+

Now that we have the group column, we can aggregate as follows to get the desired result:

df.groupBy("user_id", "group")\
    .agg(
        f.min("start_time").alias("start_time"),
        f.max("end_time").alias("end_time"),
        f.sum("duration").alias("duration")
    )\
    .drop("group")\
    .show()
#+-------+----------+--------+--------+
#|user_id|start_time|end_time|duration|
#+-------+----------+--------+--------+
#|      1|  19:00:00|19:43:00|      43|
#|      1|  20:05:00|20:35:00|      30|
#|      2|  20:00:00|20:10:00|      10|
#+-------+----------+--------+--------+

Answer 1 (score: 0)

Here is a working solution derived from Pault's answer:

Create the DataFrame:

rdd = sc.parallelize([
    (1, "19:00:00", "19:30:00", 30),
    (1, "19:30:00", "19:40:00", 10),
    (1, "19:40:00", "19:43:00", 3),
    (2, "20:00:00", "20:10:00", 10),
    (1, "20:05:00", "20:15:00", 10),
    (1, "20:15:00", "20:35:00", 20)
])

df = spark.createDataFrame(rdd, ["user_id", "start_time", "end_time", "duration"])

df.show()

+-------+----------+--------+--------+
|user_id|start_time|end_time|duration|
+-------+----------+--------+--------+
|      1|  19:00:00|19:30:00|      30|
|      1|  19:30:00|19:40:00|      10|
|      1|  19:40:00|19:43:00|       3|
|      1|  20:05:00|20:15:00|      10|
|      1|  20:15:00|20:35:00|      20|
+-------+----------+--------+--------+

Create an indicator column that flags when the time changes, and use a cumulative sum to give each group a unique ID:

import pyspark.sql.functions as f
from pyspark.sql import Window

w1 = Window.partitionBy('user_id').orderBy('start_time')
df = df.withColumn(
        "indicator",
        (f.col("start_time") != f.lag("end_time").over(w1)).cast("int")
    )\
    .fillna(
        0,
        subset=["indicator"]
    )\
    .withColumn(
        "group",
        f.sum(f.col("indicator")).over(w1.rangeBetween(Window.unboundedPreceding, 0))
    )
df.show()

+-------+----------+--------+--------+---------+-----+
|user_id|start_time|end_time|duration|indicator|group|
+-------+----------+--------+--------+---------+-----+
|      1|  19:00:00|19:30:00|      30|        0|    0|
|      1|  19:30:00|19:40:00|      10|        0|    0|
|      1|  19:40:00|19:43:00|       3|        0|    0|
|      1|  20:05:00|20:15:00|      10|        1|    1|
|      1|  20:15:00|20:35:00|      20|        0|    1|
+-------+----------+--------+--------+---------+-----+

Now group by the user ID and the group variable.
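
A sketch of that aggregation, mirroring the final step in the first answer:

# using f = pyspark.sql.functions, imported above
df.groupBy("user_id", "group")\
    .agg(
        f.min("start_time").alias("start_time"),
        f.max("end_time").alias("end_time"),
        f.sum("duration").alias("duration")
    )\
    .drop("group")\
    .show()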

+-------+----------+--------+--------+
|user_id|start_time|end_time|duration|
+-------+----------+--------+--------+
|      1|  19:00:00|19:43:00|      43|
|      1|  20:05:00|20:35:00|      30|
+-------+----------+--------+--------+