Spark: find the most frequent value per group over a rolling time window

Time: 2018-02-08 15:43:42

Tags: apache-spark pyspark

Starting with the following Spark DataFrame:

from io import StringIO
import pandas as pd
from pyspark.sql.functions import col


pd_df = pd.read_csv(StringIO("""device_id,read_date,id,count
device_A,2017-08-05,4041,3
device_A,2017-08-06,4041,3
device_A,2017-08-07,4041,4
device_A,2017-08-08,4041,3
device_A,2017-08-09,4041,3
device_A,2017-08-10,4041,1
device_A,2017-08-10,4045,2
device_A,2017-08-11,4045,3
device_A,2017-08-12,4045,3
device_A,2017-08-13,4045,3"""),infer_datetime_format=True, parse_dates=['read_date'])

df = spark.createDataFrame(pd_df).withColumn('read_date', col('read_date').cast('date'))
df.show()

Output:

+--------------+----------+----+-----+
|device_id     | read_date|  id|count|
+--------------+----------+----+-----+
|      device_A|2017-08-05|4041|    3|
|      device_A|2017-08-06|4041|    3|
|      device_A|2017-08-07|4041|    4|
|      device_A|2017-08-08|4041|    3|
|      device_A|2017-08-09|4041|    3|
|      device_A|2017-08-10|4041|    1|
|      device_A|2017-08-10|4045|    2|
|      device_A|2017-08-11|4045|    3|
|      device_A|2017-08-12|4045|    3|
|      device_A|2017-08-13|4045|    3|
+--------------+----------+----+-----+

I want to find the most common id for each (device_id, read_date) combination over a 3-day rolling window. For each group of rows selected by the time window, I need to find the most common id by summing the counts per id, then return that top id. For example, for device_A on 2017-08-11 the window covers 2017-08-09 through 2017-08-11, so the totals are 4 for id 4041 (3 + 1) and 5 for id 4045 (2 + 3), and the expected id is 4045.

Expected output:

+--------------+----------+----+
|device_id     | read_date|  id|
+--------------+----------+----+
|      device_A|2017-08-05|4041|
|      device_A|2017-08-06|4041|
|      device_A|2017-08-07|4041|
|      device_A|2017-08-08|4041|
|      device_A|2017-08-09|4041|
|      device_A|2017-08-10|4041|
|      device_A|2017-08-11|4045|
|      device_A|2017-08-12|4045|
|      device_A|2017-08-13|4045|
+--------------+----------+----+

I'm starting to think this is only possible with a custom aggregation function. Since Spark 2.3 isn't out yet, I would have to write it in Scala or use collect_list. Am I missing something?

2 answers:

Answer 0 (score: 2):

Add windows:

{{1}}

Then use one of the solutions from Find maximum row per group in Spark DataFrame:

{{1}}
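
Since this answer's code only survives as {{1}} placeholders, here is a rough sketch (mine, not the original author's code) of how the two described steps might be combined: sliding 3-day windows to sum each id's count, then the row_number-based max-row-per-group pattern. The column names w, total and rn, and the left-semi join used to drop window labels that fall outside the input dates, are my own assumptions.

from pyspark.sql import Window
from pyspark.sql.functions import col, date_sub, row_number, sum as sum_, window

# Sliding 3-day windows advancing one day at a time; the window whose
# exclusive end is read_date + 1 day covers read_date and the two days before it
windowed = (
    df
    .groupBy(
        "device_id",
        window(col("read_date").cast("timestamp"), "3 days", "1 day").alias("w"),
        "id",
    )
    .agg(sum_("count").alias("total"))
    .withColumn("read_date", date_sub(col("w.end").cast("date"), 1))
)

# Drop window labels for dates that are not present in the input data
windowed = windowed.join(
    df.select("device_id", "read_date").distinct(),
    ["device_id", "read_date"],
    "leftsemi",
)

# Max row per group: rank ids by their summed count within each (device_id, read_date)
rank_window = Window.partitionBy("device_id", "read_date").orderBy(col("total").desc())

result = (
    windowed
    .withColumn("rn", row_number().over(rank_window))
    .where(col("rn") == 1)
    .select("device_id", "read_date", "id")
    .orderBy("device_id", "read_date")
)
result.show()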

Answer 1 (score: 0):

I managed to come up with a very inefficient solution. Hopefully someone can spot improvements that avoid the Python udf and the call to collect_list.

from collections import Counter

from pyspark.sql import Window
from pyspark.sql.functions import col, collect_list, first, udf
from pyspark.sql.types import IntegerType

def top_id(ids, counts):
    # Sum the counts per id and return the id with the highest total
    c = Counter()
    for cnid, count in zip(ids, counts):
        c[cnid] += count

    return c.most_common(1)[0][0]


rolling_window = 3

days = lambda i: i * 86400

# Define a rolling calculation window based on time
window = (
    Window()
        .partitionBy("device_id")
        .orderBy(col("read_date").cast("timestamp").cast("long"))
        .rangeBetween(-days(rolling_window - 1), 0)
)

# Use window and collect_list to store data matching the window definition on each row
df_collected = df.select(
    'device_id', 'read_date',
    collect_list(col('id')).over(window).alias('ids'),
    collect_list(col('count')).over(window).alias('counts')
)

# Get rid of duplicate rows where necessary
df_grouped = df_collected.groupBy('device_id', 'read_date').agg(
    first('ids').alias('ids'),
    first('counts').alias('counts'),
)

# Register and apply udf to return the most frequently seen id
top_id_udf = udf(top_id, IntegerType())
df_mapped = df_grouped.withColumn('top_id', top_id_udf(col('ids'), col('counts')))

df_mapped.show(truncate=False)

This returns:

+---------+----------+------------------------+------------+------+
|device_id|read_date |ids                     |counts      |top_id|
+---------+----------+------------------------+------------+------+
|device_A |2017-08-05|[4041]                  |[3]         |4041  |
|device_A |2017-08-06|[4041, 4041]            |[3, 3]      |4041  |
|device_A |2017-08-07|[4041, 4041, 4041]      |[3, 3, 4]   |4041  |
|device_A |2017-08-08|[4041, 4041, 4041]      |[3, 4, 3]   |4041  |
|device_A |2017-08-09|[4041, 4041, 4041]      |[4, 3, 3]   |4041  |
|device_A |2017-08-10|[4041, 4041, 4041, 4045]|[3, 3, 1, 2]|4041  |
|device_A |2017-08-11|[4041, 4041, 4045, 4045]|[3, 1, 2, 3]|4045  |
|device_A |2017-08-12|[4041, 4045, 4045, 4045]|[1, 2, 3, 3]|4045  |
|device_A |2017-08-13|[4045, 4045, 4045]      |[3, 3, 3]   |4045  |
+---------+----------+------------------------+------------+------+
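
As a final step (my addition, not part of the original answer), the helper columns can be dropped and top_id renamed to id to match the expected output:

df_mapped.select('device_id', 'read_date', col('top_id').alias('id')).show()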