Starting with the following Spark DataFrame:
from io import StringIO
import pandas as pd
from pyspark.sql.functions import col
pd_df = pd.read_csv(StringIO("""device_id,read_date,id,count
device_A,2017-08-05,4041,3
device_A,2017-08-06,4041,3
device_A,2017-08-07,4041,4
device_A,2017-08-08,4041,3
device_A,2017-08-09,4041,3
device_A,2017-08-10,4041,1
device_A,2017-08-10,4045,2
device_A,2017-08-11,4045,3
device_A,2017-08-12,4045,3
device_A,2017-08-13,4045,3"""), parse_dates=['read_date'])
df = spark.createDataFrame(pd_df).withColumn('read_date', col('read_date').cast('date'))
df.show()
Output:
+---------+----------+----+-----+
|device_id| read_date|  id|count|
+---------+----------+----+-----+
| device_A|2017-08-05|4041|    3|
| device_A|2017-08-06|4041|    3|
| device_A|2017-08-07|4041|    4|
| device_A|2017-08-08|4041|    3|
| device_A|2017-08-09|4041|    3|
| device_A|2017-08-10|4041|    1|
| device_A|2017-08-10|4045|    2|
| device_A|2017-08-11|4045|    3|
| device_A|2017-08-12|4045|    3|
| device_A|2017-08-13|4045|    3|
+---------+----------+----+-----+
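I want to find the most common id for each (device_id, read_date) combination over a rolling 3-day window. For each group of rows selected by the time window, I need to sum the counts per id to find the most common id, and then return that top id.
Expected output: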
+---------+----------+----+
|device_id| read_date|  id|
+---------+----------+----+
| device_A|2017-08-05|4041|
| device_A|2017-08-06|4041|
| device_A|2017-08-07|4041|
| device_A|2017-08-08|4041|
| device_A|2017-08-09|4041|
| device_A|2017-08-10|4041|
| device_A|2017-08-11|4045|
| device_A|2017-08-12|4045|
| device_A|2017-08-13|4045|
+---------+----------+----+
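To pin down the semantics, here is a minimal plain-Python sketch (no Spark involved) of the calculation I have in mind. The function name `rolling_top_id` is made up for illustration, the window is taken to be the current day plus the two preceding days, and breaking ties in favour of the larger id is an assumption, since the sample data contains no ties.

```python
from collections import Counter, defaultdict
from datetime import date, timedelta

# Sample rows from the question: (device_id, read_date, id, count)
rows = [
    ("device_A", date(2017, 8, 5), 4041, 3),
    ("device_A", date(2017, 8, 6), 4041, 3),
    ("device_A", date(2017, 8, 7), 4041, 4),
    ("device_A", date(2017, 8, 8), 4041, 3),
    ("device_A", date(2017, 8, 9), 4041, 3),
    ("device_A", date(2017, 8, 10), 4041, 1),
    ("device_A", date(2017, 8, 10), 4045, 2),
    ("device_A", date(2017, 8, 11), 4045, 3),
    ("device_A", date(2017, 8, 12), 4045, 3),
    ("device_A", date(2017, 8, 13), 4045, 3),
]


def rolling_top_id(rows, window_days=3):
    """Map (device_id, read_date) to the id with the largest summed count
    over a trailing window of `window_days` days ending on read_date."""
    by_device = defaultdict(list)
    for device_id, read_date, id_, count in rows:
        by_device[device_id].append((read_date, id_, count))

    result = {}
    for device_id, readings in by_device.items():
        # One output row per distinct date that has at least one reading
        for current in sorted({d for d, _, _ in readings}):
            start = current - timedelta(days=window_days - 1)
            totals = Counter()
            for read_date, id_, count in readings:
                if start <= read_date <= current:
                    totals[id_] += count
            # Assumption: ties on the total are broken by the larger id
            result[(device_id, current)] = max(totals, key=lambda i: (totals[i], i))
    return result


for (device_id, read_date), top in sorted(rolling_top_id(rows).items()):
    print(device_id, read_date, top)
```

On 2017-08-10, for example, the window covers 2017-08-08 through 2017-08-10, where id 4041 totals 3 + 3 + 1 = 7 and id 4045 totals 2, so 4041 is returned.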
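I'm starting to think this is only possible with a custom aggregation function. Since Spark 2.3 isn't out yet, I would have to write this in Scala or use collect_list. Am I missing something?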
Answer 0 (score: 2)
Answer 1 (score: 0)
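I managed to find a solution, although it is very inefficient. Hopefully someone can spot improvements that avoid the Python UDF and the calls to collect_list.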
from collections import Counter

from pyspark.sql import Window
from pyspark.sql.functions import col, collect_list, first, udf
from pyspark.sql.types import IntegerType

def top_id(ids, counts):
    # Sum the counts per id, then return the id with the largest total
    c = Counter()
    for cnid, count in zip(ids, counts):
        c[cnid] += count
    return c.most_common(1)[0][0]
rolling_window = 3
days = lambda i: i * 86400
# Define a rolling calculation window based on time
window = (
    Window
    .partitionBy("device_id")
    .orderBy(col("read_date").cast("timestamp").cast("long"))
    .rangeBetween(-days(rolling_window - 1), 0)
)
# Use window and collect_list to store data matching the window definition on each row
df_collected = df.select(
    'device_id', 'read_date',
    collect_list(col('id')).over(window).alias('ids'),
    collect_list(col('count')).over(window).alias('counts')
)
# Get rid of duplicate rows where necessary
df_grouped = df_collected.groupBy('device_id', 'read_date').agg(
    first('ids').alias('ids'),
    first('counts').alias('counts'),
)
# Register and apply udf to return the most frequently seen id
top_id_udf = udf(top_id, IntegerType())
df_mapped = df_grouped.withColumn('top_id', top_id_udf(col('ids'), col('counts')))
df_mapped.show(truncate=False)
Returns:
+---------+----------+------------------------+------------+------+
|device_id|read_date |ids |counts |top_id|
+---------+----------+------------------------+------------+------+
|device_A |2017-08-05|[4041] |[3] |4041 |
|device_A |2017-08-06|[4041, 4041] |[3, 3] |4041 |
|device_A |2017-08-07|[4041, 4041, 4041] |[3, 3, 4] |4041 |
|device_A |2017-08-08|[4041, 4041, 4041] |[3, 4, 3] |4041 |
|device_A |2017-08-09|[4041, 4041, 4041] |[4, 3, 3] |4041 |
|device_A |2017-08-10|[4041, 4041, 4041, 4045]|[3, 3, 1, 2]|4041 |
|device_A |2017-08-11|[4041, 4041, 4045, 4045]|[3, 1, 2, 3]|4045 |
|device_A |2017-08-12|[4041, 4045, 4045, 4045]|[1, 2, 3, 3]|4045 |
|device_A |2017-08-13|[4045, 4045, 4045] |[3, 3, 3] |4045 |
+---------+----------+------------------------+------------+------+
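One possible direction for dropping both collect_list and the Python UDF is to turn the rolling window into an ordinary aggregation: copy each reading onto every date whose trailing window contains it, sum the counts per (device_id, date, id), and keep the id with the largest total for each date that actually has readings. The sketch below only demonstrates that logic in plain Python on the sample data. The function name `rolling_top_id_by_aggregation` is hypothetical, ties going to the larger id is an assumption, and I have not verified how well the equivalent Spark plan (roughly an `explode` of a date range, a `groupBy` with `sum`, and a `row_number` over a window) would perform.

```python
from collections import defaultdict
from datetime import date, timedelta

# Sample rows from the question: (device_id, read_date, id, count)
rows = [
    ("device_A", date(2017, 8, 5), 4041, 3),
    ("device_A", date(2017, 8, 6), 4041, 3),
    ("device_A", date(2017, 8, 7), 4041, 4),
    ("device_A", date(2017, 8, 8), 4041, 3),
    ("device_A", date(2017, 8, 9), 4041, 3),
    ("device_A", date(2017, 8, 10), 4041, 1),
    ("device_A", date(2017, 8, 10), 4045, 2),
    ("device_A", date(2017, 8, 11), 4045, 3),
    ("device_A", date(2017, 8, 12), 4045, 3),
    ("device_A", date(2017, 8, 13), 4045, 3),
]


def rolling_top_id_by_aggregation(rows, window_days=3):
    """Rolling top id expressed as explode -> sum -> pick best per key."""
    totals = defaultdict(int)  # (device_id, date, id) -> summed count
    observed = set()           # (device_id, date) pairs that have readings

    # Step 1: a reading on day d falls inside the trailing windows of
    # d, d+1, ..., d+window_days-1, so add its count to each of them.
    for device_id, read_date, id_, count in rows:
        observed.add((device_id, read_date))
        for offset in range(window_days):
            day = read_date + timedelta(days=offset)
            totals[(device_id, day, id_)] += count

    # Step 2: per (device_id, date), keep the id with the largest total.
    # Dates without any reading of their own are dropped.
    best = {}
    for (device_id, day, id_), total in totals.items():
        key = (device_id, day)
        if key not in observed:
            continue
        # Assumption: ties on the total are broken by the larger id
        if key not in best or (total, id_) > best[key]:
            best[key] = (total, id_)
    return {key: id_ for key, (_, id_) in best.items()}


for (device_id, read_date), top in sorted(rolling_top_id_by_aggregation(rows).items()):
    print(device_id, read_date, top)
```

The trade-off is that every input row is duplicated once per day in the window, which is cheap for a 3-day window but grows linearly with the window length.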