Spark overlap algorithm using DataFrames

Date: 2019-02-04 15:53:11

Tags: python apache-spark pyspark apache-spark-sql

Given a data source with the following fields: product_id, product, start_time, end_time

I am trying to capture, using DataFrame functions, the logic that identifies overlapping records for the same product (based on start_time and end_time).

------------------------------------------------
| product_id | product | start_time | end_time |
------------------------------------------------
|      1     | bottle  |     2      |    4     |
|      2     | bottle  |     3      |    5     |
|      3     | bottle  |     2      |    3     |
|      4     | bottle  |     6      |    7     |
|      1     |   can   |     2      |    4     |
|      2     |   can   |     5      |    6     |
|      3     |   can   |     2      |    4     |

I would like to obtain the following output:

-------------------------------------------------------------------------------------------------
| product_id_a | product_id_b | product | start_time_a | end_time_a | start_time_b | end_time_b |
-------------------------------------------------------------------------------------------------
|       1      |       2      | bottle  |      2       |     4      |      3       |     5      |
|       1      |       3      | bottle  |      2       |     4      |      2       |     3      |
-------------------------------------------------------------------------------------------------

since bottle_1, bottle_2 and bottle_3 overlap in time. Two records a and b overlap when the following conditions are met:

  • max(a.start_time, b.start_time) < min(a.end_time, b.end_time)
  • !(a.start_time == b.start_time && a.end_time == b.end_time)
  • a.start_time != b.start_time || a.end_time != b.end_time

The last two conditions simply state that I am not interested in pairs whose start_time and end_time are both equal (for example, can_1 and can_3 are not in the expected output even though they share the same start_time and end_time).
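
To make the rule concrete, here is a tiny plain-Python check of the conditions (an illustration only; the overlaps helper is a hypothetical name, and the last two conditions collapse to "the intervals are not identical"):

def overlaps(a, b):
    # a and b are (start_time, end_time) tuples
    # overlap: the later start is strictly before the earlier end,
    # and the two intervals must not be identical
    return max(a[0], b[0]) < min(a[1], b[1]) and a != b

print(overlaps((2, 4), (3, 5)))  # bottle_1 vs bottle_2 -> True
print(overlaps((2, 4), (2, 3)))  # bottle_1 vs bottle_3 -> True
print(overlaps((2, 4), (2, 4)))  # can_1 vs can_3 (identical) -> False
print(overlaps((2, 4), (6, 7)))  # bottle_1 vs bottle_4 -> False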

A MapReduce solution using RDDs is easy to come up with for this problem, but I am interested in a solution based on DataFrames.

Hint: is it possible to use groupBy().agg() with a condition that expresses the described logic?

Feel free to ask for any further clarification.

This is not a duplicate of How to aggregate over rolling time window with groups in Spark.

Unfortunately, the answer reported there uses F.lag, which is not good enough in my case: F.lag only compares with the previous record, so it would not work as expected on the example above, since bottle_1 would not be reported as overlapping with bottle_3 (they are not consecutive records).

3 Answers:

Answer 0 (score: 2)

Each condition can be translated directly into a Spark SQL expression.

So you can join the DataFrame with itself and filter using these conditions:

from pyspark.sql.functions import col, least, greatest

# condition 1: the two intervals overlap
cond1 = (
    greatest(col("a.start_time"), col("b.start_time")) <
    least(col("a.end_time"), col("b.end_time"))
)

# condition 2: the two intervals are not identical
cond2 = ~(
    (col("a.start_time") == col("b.start_time")) &
    (col("a.end_time") == col("b.end_time"))
)

# condition 3: equivalent reformulation of condition 2
cond3 = (
    (col("a.start_time") != col("b.start_time")) |
    (col("a.end_time") != col("b.end_time"))
)
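
A minimal sketch of the join-and-filter step (my own completion, assuming df holds the sample data from the question; the a/b aliases and the product_id ordering are not part of the original answer):

df_a = df.alias("a")
df_b = df.alias("b")

# Self-join on product and keep only overlapping, non-identical intervals.
# The product_id ordering keeps each pair once instead of also emitting the mirror pair.
overlapping = (
    df_a.join(df_b, "product")
        .where(cond1 & cond2 & cond3)
        .where(col("a.product_id") < col("b.product_id"))
)
overlapping.show()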

Answer 1 (score: 1)

Based on @Andronicus's solution, I came up with this approach in pure Python.

The DataFrame needs to be joined with itself in order to check whether rows overlap. Of course, you need to exclude the trivial matches (a row paired with itself, and the same pair with the product_ids swapped); the condition df.product_id < duplicate_df.product_id takes care of that.

The whole code:

from pyspark.sql import functions as F

df = spark.createDataFrame(
    [(1, "bottle", 2, 4),
     (2, "bottle", 3, 5),
     (3, "bottle", 2, 3),
     (4, "bottle", 6, 7),
     (1, "can", 2, 4),
     (2, "can", 5, 6),
     (3, "can", 2, 4)], 
     ['product_id', 'product', 'start_time', 'end_time'])

duplicate_df = df

conditions = [df.product == duplicate_df.product,
              # keep each pair once and drop self-matches
              df.product_id < duplicate_df.product_id,
              # at least one boundary differs, i.e. the intervals are not identical
              (df.start_time != duplicate_df.start_time) |
              (df.end_time != duplicate_df.end_time),
              # the intervals actually overlap
              F.least(df.end_time, duplicate_df.end_time) >
              F.greatest(df.start_time, duplicate_df.start_time)]

df.join(duplicate_df, conditions)
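
If the self-join complains about ambiguous column references, or to get exactly the column names from the expected output, an aliased variant is possible (a sketch; the alias names and the final select are my own additions, not part of the original answer):

a, b = df.alias("a"), df.alias("b")

pairs = (
    a.join(b, [F.col("a.product") == F.col("b.product"),
               F.col("a.product_id") < F.col("b.product_id"),
               F.least(F.col("a.end_time"), F.col("b.end_time")) >
               F.greatest(F.col("a.start_time"), F.col("b.start_time")),
               (F.col("a.start_time") != F.col("b.start_time")) |
               (F.col("a.end_time") != F.col("b.end_time"))])
     .select(F.col("a.product_id").alias("product_id_a"),
             F.col("b.product_id").alias("product_id_b"),
             F.col("a.product").alias("product"),
             F.col("a.start_time").alias("start_time_a"),
             F.col("a.end_time").alias("end_time_a"),
             F.col("b.start_time").alias("start_time_b"),
             F.col("b.end_time").alias("end_time_b"))
)
pairs.show()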

Answer 2 (score: 0)

Try this:

df.crossJoin(cloneDf).where($"product_id" =!= $"product_id1").where($"min" < $"max1").where($"min1" < $"max").show()

You need to make a Cartesian product of the DataFrame in order to check whether the rows overlap, and you can then map them as needed. Of course, you need to omit the self-pairs, since two identical Rows trivially overlap.

The whole code:

val df = SparkEmbedded.ss.createDataFrame(Seq(
  (1, 2, 5),
  (2, 4, 7),
  (3, 6, 9)
)).toDF("product_id", "min", "max")
import org.apache.spark.sql.functions.col
import SparkEmbedded.ss.implicits._
val cloneDf = df.select(df.columns.map(col):_*)
    .withColumnRenamed("product_id", "product_id1")
    .withColumnRenamed("min", "min1")
    .withColumnRenamed("max", "max1")
df.crossJoin(cloneDf)
  .where($"product_id" < $"product_id1")
  .where($"min" < $"max1")
  .where($"min1" < $"max").show()

I have split the where clauses for clarity.

The result is:

+----------+---+---+-----------+----+----+
|product_id|min|max|product_id1|min1|max1|
+----------+---+---+-----------+----+----+
|         1|  2|  5|          2|   4|   7|
|         2|  4|  7|          3|   6|   9|
+----------+---+---+-----------+----+----+

The example is in Scala, but Python has a similar API.
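
For reference, a rough PySpark equivalent of the same crossJoin idea (a sketch assuming an active spark session; the column names follow the Scala snippet above):

from pyspark.sql.functions import col

df = spark.createDataFrame(
    [(1, 2, 5), (2, 4, 7), (3, 6, 9)],
    ["product_id", "min", "max"])

clone_df = (df
    .withColumnRenamed("product_id", "product_id1")
    .withColumnRenamed("min", "min1")
    .withColumnRenamed("max", "max1"))

(df.crossJoin(clone_df)
   .where(col("product_id") < col("product_id1"))  # skip self and mirrored pairs
   .where(col("min") < col("max1"))                # a starts before b ends ...
   .where(col("min1") < col("max"))                # ... and b starts before a ends
   .show())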