Given a data source with the following fields:
- product_id
- product
- start_time
- end_time

I am trying to capture the logic for overlapping records of the same product (based on start_time and end_time) using DataFrame functions.
------------------------------------------------
| product_id | product | start_time | end_time |
------------------------------------------------
| 1          | bottle  | 2          | 4        |
| 2          | bottle  | 3          | 5        |
| 3          | bottle  | 2          | 3        |
| 4          | bottle  | 6          | 7        |
| 1          | can     | 2          | 4        |
| 2          | can     | 5          | 6        |
| 3          | can     | 2          | 4        |
------------------------------------------------
The output I would like to get is:
-------------------------------------------------------------------------------------------------
| product_id_a | product_id_b | product | start_time_a | end_time_a | start_time_b | end_time_b |
-------------------------------------------------------------------------------------------------
| 1            | 2            | bottle  | 2            | 4          | 3            | 5          |
| 1            | 3            | bottle  | 2            | 4          | 2            | 3          |
-------------------------------------------------------------------------------------------------
since bottle_1 overlaps in time with both bottle_2 and bottle_3.
Two records overlap if they satisfy the following conditions:
max(a.start_time, b.start_time) < min(a.end_time, b.end_time)
!(a.start_time == b.start_time && a.end_time == b.end_time)
a.start_time != b.start_time || a.end_time != b.end_time
The last two conditions simply state that I am not interested in cases where start_time and end_time are both equal (for example, can_1 and can_3 are not in the expected output even though they have the same start_time and end_time).
A MapReduce solution using RDDs is easy to come up with for this problem, but I am interested in a solution that uses DataFrames.

Hint: is it possible to specify, with groupBy().agg(), an interesting condition that achieves the described logic?
Feel free to ask for any further clarification.

This is not a duplicate of How to aggregate over rolling time window with groups in Spark: unfortunately, the answer reported there uses F.lag, which is not good enough in my case. F.lag only compares against the previous record, so it would not work as expected here, because bottle_1 would not be reported as overlapping with bottle_3, since they are not consecutive records.
Answer 0 (score: 2)
Each of the conditions can be directly translated to SQL expressions,
so you can just join and filter:
from pyspark.sql.functions import col, least, greatest

# max(a.start_time, b.start_time) < min(a.end_time, b.end_time)
cond1 = (
    greatest(col("a.start_time"), col("b.start_time")) <
    least(col("a.end_time"), col("b.end_time"))
)

# !(a.start_time == b.start_time && a.end_time == b.end_time)
cond2 = ~(
    (col("a.start_time") == col("b.start_time")) &
    (col("a.end_time") == col("b.end_time"))
)

# a.start_time != b.start_time || a.end_time != b.end_time (an equivalent restatement of cond2)
cond3 = (
    (col("a.start_time") != col("b.start_time")) |
    (col("a.end_time") != col("b.end_time"))
)
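A minimal sketch of the join-and-filter step (my own illustration, not from the original answer; it reuses cond1, cond2, cond3 and the col import above and assumes the question's data is loaded into a DataFrame named df):

# Illustrative only: self-join df under the aliases "a" and "b" that the
# cond1/cond2/cond3 column references expect, restrict pairs to the same
# product, and keep each pair once by ordering on product_id.
overlapping = (
    df.alias("a")
      .join(
          df.alias("b"),
          (col("a.product") == col("b.product")) &
          (col("a.product_id") < col("b.product_id")) &
          cond1 & cond2 & cond3
      )
)
overlapping.show()

With the sample data above, this keeps exactly the (bottle_1, bottle_2) and (bottle_1, bottle_3) pairs from the expected output.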
Answer 1 (score: 1)
Based on @Andronicus' solution, I came up with this approach in pure Python.

You need to join the DataFrame with itself to check whether rows overlap. Of course, you need to omit the self-matches by requiring the product_id on one side to be smaller than on the other (two identical Rows overlap, and so does the same pair with the product_ids swapped).

The whole code:
from pyspark.sql import functions as F

df = spark.createDataFrame(
    [(1, "bottle", 2, 4),
     (2, "bottle", 3, 5),
     (3, "bottle", 2, 3),
     (4, "bottle", 6, 7),
     (1, "can", 2, 4),
     (2, "can", 5, 6),
     (3, "can", 2, 4)],
    ['product_id', 'product', 'start_time', 'end_time'])

# Alias both sides of the self-join so the column references stay unambiguous.
duplicate_df = df.alias("duplicate_df")
df = df.alias("df")

conditions = [
    # only compare records of the same product
    F.col("df.product") == F.col("duplicate_df.product"),
    # drop the self-pair and keep each pair in one order only
    F.col("df.product_id") < F.col("duplicate_df.product_id"),
    # skip pairs whose intervals are exactly identical (e.g. can_1 and can_3)
    (F.col("df.start_time") != F.col("duplicate_df.start_time")) |
    (F.col("df.end_time") != F.col("duplicate_df.end_time")),
    # the actual overlap test
    F.least(F.col("df.end_time"), F.col("duplicate_df.end_time")) >
    F.greatest(F.col("df.start_time"), F.col("duplicate_df.start_time"))]

df.join(duplicate_df, conditions)
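To get output shaped like the table in the question, one possible follow-up (my own sketch, not part of the original answer; it reuses df, duplicate_df and conditions from the code above) is to project the two sides into _a/_b columns:

overlaps = df.join(duplicate_df, conditions)

# Illustrative projection into the column layout requested in the question.
overlaps.select(
    F.col("df.product_id").alias("product_id_a"),
    F.col("duplicate_df.product_id").alias("product_id_b"),
    F.col("df.product").alias("product"),
    F.col("df.start_time").alias("start_time_a"),
    F.col("df.end_time").alias("end_time_a"),
    F.col("duplicate_df.start_time").alias("start_time_b"),
    F.col("duplicate_df.end_time").alias("end_time_b"),
).show()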
Answer 2 (score: 0)
Try this:
df.crossJoin(cloneDf).where($"product_id" < $"product_id1").where($"min" < $"max1").where($"min1" < $"max").show()
You need to make a Cartesian product of the DataFrame with itself to check whether rows overlap; you can then map the matches however you need. Of course, you need to omit the self-matches, since two identical Rows always overlap.

The whole code:
import org.apache.spark.sql.functions.col
import SparkEmbedded.ss.implicits._

val df = SparkEmbedded.ss.createDataFrame(Seq(
  (1, 2, 5),
  (2, 4, 7),
  (3, 6, 9)
)).toDF("product_id", "min", "max")

// Clone the DataFrame with renamed columns so the two sides can be told apart.
val cloneDf = df.select(df.columns.map(col): _*)
  .withColumnRenamed("product_id", "product_id1")
  .withColumnRenamed("min", "min1")
  .withColumnRenamed("max", "max1")

df.crossJoin(cloneDf)
  .where($"product_id" < $"product_id1")
  .where($"min" < $"max1")
  .where($"min1" < $"max")
  .show()
I have split the where clauses for clarity.

The result is:
+----------+---+---+-----------+----+----+
|product_id|min|max|product_id1|min1|max1|
+----------+---+---+-----------+----+----+
|         1|  2|  5|          2|   4|   7|
|         2|  4|  7|          3|   6|   9|
+----------+---+---+-----------+----+----+
The example is in Scala, but Python has a similar API.
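For reference, here is a rough PySpark rendering of the same idea (my own sketch, not from the original answer; it assumes a SparkSession named spark and mirrors the Scala column names):

from pyspark.sql import functions as F

# Same toy data as the Scala example: (product_id, min, max) intervals.
df = spark.createDataFrame(
    [(1, 2, 5), (2, 4, 7), (3, 6, 9)],
    ["product_id", "min", "max"])

# Clone with renamed columns so the two sides of the cross join stay distinct.
clone_df = (df
    .withColumnRenamed("product_id", "product_id1")
    .withColumnRenamed("min", "min1")
    .withColumnRenamed("max", "max1"))

(df.crossJoin(clone_df)
   .where(F.col("product_id") < F.col("product_id1"))
   .where(F.col("min") < F.col("max1"))
   .where(F.col("min1") < F.col("max"))
   .show())

As in the Scala version, renaming the columns of the clone avoids any ambiguity between the two sides of the cross join.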