[Py] Spark SQL:根据不同列

时间:2017-09-24 16:44:09

标签: apache-spark pyspark apache-spark-sql spark-dataframe

我按groupn1n2订购了以下DataFrame

+-----+--+--+------+------+                                        
|group|n1|n2|n1_ptr|n2_ptr|                                
+-----+--+--+------+------+                                        
|    1| 0| 0|     1|     1|                                        
|    1| 1| 1|     2|     2|                                        
|    1| 1| 5|     2|     6|                                        
|    1| 2| 2|     3|     3|                                        
|    1| 2| 6|     3|     7|                                        
|    1| 3| 3|     4|     4|                                        
|    1| 3| 7|  null|  null|                                        
|    1| 4| 4|     5|     5|                                        
|    1| 5| 1|  null|  null|                                        
|    1| 5| 5|  null|  null|                                        
+-----+--+--+------+------+

每行的n1_ptrn2_ptr值指的是排序后面的组中其他行的n1n2值。换句话说,n1_ptrn2_ptr实际上是指向另一行的指针。我想使用这些指针来识别(n1, n2)对的链。例如,给定数据中的链将是:(0,0) - > (1,1) - > (2,2) - > (3,3) - > (4,4) - > (5,5); (1,5) - > (2,6) - > (3,7);和(5,1)

最终目标是将每个链合并为DataFrame中的单行,描述每个链中的最小和最大n1n2值。继续这个例子,这将产生

+-----+------+------+------+------+
|group|n1_min|n2_min|n1_max|n2_max|
+-----+------+------+------+------+       
|    1|     0|     0|     5|     5|
|    1|     1|     5|     3|     7|
|    1|     5|     1|     5|     1| 
+-----+------+------+------+------+

seems这样的udf可能会成功,但我对性能感到担忧。是否有更合理/更有效的方法来解决这个问题?

1 个答案:

答案 0 :(得分:2)

一个好的解决方案是使用graphframeshttps://graphframes.github.io/quick-start.html

首先让我们改变初始数据框的结构:

import pyspark.sql.functions as psf
df = sc.parallelize([[1, 0, 0, 1, 1],[1, 1, 1, 2, 2],[1, 1, 5, 2, 6],
                     [1, 2, 2, 3, 3],[1, 2, 6, 3, 7],[1, 3, 3, 4, 4],
                     [1, 3, 7, None, None],[1, 4, 4, 5, 5],[1, 5, 1, None, None],
                     [1, 5, 5, None, None]]).toDF(["group","n1","n2","n1_ptr","n2_ptr"]).filter("n1_ptr IS NOT NULL")
df = df.select(
    "group",
    psf.struct("n1", "n2").alias("src"), 
    psf.struct(df.n1_ptr.alias("n1"), df.n2_ptr.alias("n2")).alias("dst"))

df我们构建一个顶点和一个边缘数据框:

v = df.select(
    "group", 
    psf.explode(psf.array("src", "dst")).alias("id"))
e = df.drop("group")

下一步是使用graphframes找到所有连接的组件:

from graphframes import *
g = GraphFrame(v, e)
res = g.connectedComponents()

    +-----+-----+------------+
    |group|   id|   component|
    +-----+-----+------------+
    |    1|[0,0]|309237645312|
    |    1|[1,1]|309237645312|
    |    1|[1,1]|309237645312|
    |    1|[2,2]|309237645312|
    |    1|[1,5]| 85899345920|
    |    1|[2,6]| 85899345920|
    |    1|[2,2]|309237645312|
    |    1|[3,3]|309237645312|
    |    1|[2,6]| 85899345920|
    |    1|[3,7]| 85899345920|
    |    1|[3,3]|309237645312|
    |    1|[4,4]|309237645312|
    |    1|[3,7]| 85899345920|
    |    1|[4,4]|309237645312|
    |    1|[5,5]|309237645312|
    |    1|[5,1]|292057776128|
    |    1|[5,5]|309237645312|
    +-----+-----+------------+

现在,由于图表边缘中的关系意味着节点数n1n2单调增加,我们可以简单地按组件汇总并计算minmax

res.groupBy("group", "component").agg(
    psf.min("id").alias("min_id"), 
    psf.max("id").alias("max_id")
)

    +-----+------------+------+------+
    |group|   component|min_id|max_id|
    +-----+------------+------+------+
    |    1|309237645312| [0,0]| [5,5]|
    |    1| 85899345920| [1,5]| [3,7]|
    |    1|292057776128| [5,1]| [5,1]|
    +-----+------------+------+------+