如何计算包含相同值的间隔(行集)的开始/结束?

时间:2019-07-17 08:19:34

标签: apache-spark pyspark apache-spark-sql pyspark-sql

假设我们有一个Spark DataFrame,其外观如下(按time排序):

+------+-------+
| time | value |
+------+-------+
|    1 | A     |
|    2 | A     |
|    3 | A     |
|    4 | B     |
|    5 | B     |
|    6 | A     |
+------+-------+

我想计算每个不间断值序列的开始/结束时间。上述DataFrame的预期输出为:

+-------+-------+-----+
| value | start | end |
+-------+-------+-----+
| A     |     1 |   3 |
| B     |     4 |   5 |
| A     |     6 |   6 |
+-------+-------+-----+

(最后一行的end值也可以是null。)

通过简单的组聚合进行操作:

.groupBy("value")
.agg(
    F.min("time").alias("start"),
    F.max("time").alias("end")
)

没有考虑到同一value可以以多个不同的间隔出现的事实。

1 个答案:

答案 0 :(得分:1)

这个想法是为每个组创建一个标识符,并使用它来分组并计算您的最小和最大时间。

假设df是您的数据框:

from pyspark.sql import functions as F, Window

df = df.withColumn(
    "fg", 
    F.when(
        F.lag('value').over(Window.orderBy("time"))==F.col("value"), 
        0
    ).otherwise(1)
)

df = df.withColumn(
    "rn",     
    F.sum("fg").over(
        Window
        .orderBy("time")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
)

从那时起,您将获得一个数据框,其中包含每个连续组的标识符。

df.show()

+----+-----+---+---+                                                            
|time|value| rn| fg|
+----+-----+---+---+
|   1|    A|  1|  1|
|   2|    A|  1|  0|
|   3|    A|  1|  0|
|   4|    B|  2|  1|
|   5|    B|  2|  0|
|   6|    A|  3|  1|
+----+-----+---+---+

然后您只需要进行汇总

df.groupBy(
    'value', 
    "rn"
).agg(
    F.min('time').alias("start"),
    F.max('time').alias("end")
).drop("rn").show()
+-----+-----+---+                                                               
|value|start|end|
+-----+-----+---+
|    A|    1|  3|
|    B|    4|  5|
|    A|    6|  6|
+-----+-----+---+