[Py] Spark SQL:多列会话

时间:2017-09-24 19:01:43

标签: apache-spark pyspark apache-spark-sql spark-dataframe

给出一个积极的长i和一个DataFrame

+-----+--+--+                                                          
|group|n1|n2|                                                              
+-----+--+--+                                                              
|    1| 0| 0|                                                              
|    1| 1| 1|                                                              
|    1| 1| 5|                                                              
|    1| 2| 2|                                                              
|    1| 2| 6|                                                              
|    1| 3| 3|                                                              
|    1| 3| 7|                                                              
|    1| 4| 4|                                                              
|    1| 5| 1|                                                              
|    1| 5| 5|                                                              
+-----+--+--+

您在同一group中如何sessionize行,以便对于会话中的每对连续行r1r2r2.n1> r1.n1r2.n2> r1.n2和最高(r2.n1 - r1.n1r2.n2 - r1.n2)< i?请注意,n1n2值可能不唯一,这意味着构成会话的行在DataFrame中可能不是连续的。

例如,给定DataFrame和i = 3的结果将是

+-----+--+--+-------+
|group|n1|n2|session|
+-----+--+--+-------+
|    1| 0| 0|      1|
|    1| 1| 1|      1|
|    1| 1| 5|      2|
|    1| 2| 2|      1|
|    1| 2| 6|      2|
|    1| 3| 3|      1|
|    1| 3| 7|      2|
|    1| 4| 4|      1|
|    1| 5| 1|      3|
|    1| 5| 5|      1|
+-----+--+--+-------+

任何帮助或提示将不胜感激。谢谢!

1 个答案:

答案 0 :(得分:1)

这看起来像是在尝试使用相同的数字标记图表中所有连接的部分。一个好的解决方案是使用graphframeshttps://graphframes.github.io/quick-start.html

从您的数据框:

df = sc.parallelize([[1, 0, 0],[1, 1, 1],[1, 1, 5],[1, 2, 2],[1, 2, 6],
                    [1, 3, 3],[1, 3, 7],[1, 4, 4],[1, 5, 1],[1, 5, 5]]).toDF(["group","n1","n2"])

我们将创建一个包含唯一id s列表的顶点数据框:

import pyspark.sql.functions as psf
v = df.select(psf.struct("n1", "n2").alias("id"), "group")

    +-----+-----+
    |   id|group|
    +-----+-----+
    |[0,0]|    1|
    |[1,1]|    1|
    |[1,5]|    1|
    |[2,2]|    1|
    |[2,6]|    1|
    |[3,3]|    1|
    |[3,7]|    1|
    |[4,4]|    1|
    |[5,1]|    1|
    |[5,5]|    1|
    +-----+-----+

根据您声明的布尔条件定义的边缘数据帧:

i = 3
e = df.alias("r1").join(
    df.alias("r2"), 
    (psf.col("r1.group") == psf.col("r2.group"))
    & (psf.col("r1.n1") < psf.col("r2.n1"))
    & (psf.col("r1.n2") < psf.col("r2.n2"))
    & (psf.greatest(
        psf.col("r2.n1") - psf.col("r1.n1"),
        psf.col("r2.n2") - psf.col("r1.n2")) < i)
).select(psf.struct("r1.n1", "r1.n2").alias("src"), psf.struct("r2.n1", "r2.n2").alias("dst"))

    +-----+-----+
    |  src|  dst|
    +-----+-----+
    |[0,0]|[1,1]|
    |[0,0]|[2,2]|
    |[1,1]|[2,2]|
    |[1,1]|[3,3]|
    |[1,5]|[2,6]|
    |[1,5]|[3,7]|
    |[2,2]|[3,3]|
    |[2,2]|[4,4]|
    |[2,6]|[3,7]|
    |[3,3]|[4,4]|
    |[3,3]|[5,5]|
    |[4,4]|[5,5]|
    +-----+-----+

现在找到所有连接的组件:

from graphframes import *
g = GraphFrame(v, e)
res = g.connectedComponents()

    +-----+-----+------------+
    |   id|group|   component|
    +-----+-----+------------+
    |[0,0]|    1|309237645312|
    |[1,1]|    1|309237645312|
    |[1,5]|    1| 85899345920|
    |[2,2]|    1|309237645312|
    |[2,6]|    1| 85899345920|
    |[3,3]|    1|309237645312|
    |[3,7]|    1| 85899345920|
    |[4,4]|    1|309237645312|
    |[5,1]|    1|292057776128|
    |[5,5]|    1|309237645312|
    +-----+-----+------------+