pyspark:如果列在不同的行中具有相等的值,则合并两行或更多行

时间:2017-12-28 07:55:16

标签: python python-3.x python-2.7 pyspark spark-graphx

由于数据量很大,我必须使用pyspark将数据帧值(列表)组合在不同的行中。

这样的数据框:

x = sqlContext.createDataFrame([("A", ['1','2','3']),("B", ['4','2','5','6']),("C", ['2','4','9','10']),("D", ['11','12','15','16'])],["index", "num_group"])
+-----+----------------+
|index|       num_group|
+-----+----------------+
|    A|       [1, 2, 3]|
|    B|    [4, 2, 5, 6]|
|    C|   [2, 4, 9, 10]|
|    D|[11, 12, 15, 16]|
+-----+----------------+

我希望按列表合并num_group,它们具有相同的元素: (索引是无意义的值或字符串)

+-------------------------+
|                num_group|
+-------------------------+
|[1, 2, 3, 4, 5, 6, 9, 10]|
|         [11, 12, 15, 16]|
+-------------------------+

我想我可以使用graphframes GraphX来查找连接,并根据不同行中的相等值合并两行或更多行。

有可能吗?我并不是真正理解documents的例子。

非常感谢任何帮助。

1 个答案:

答案 0 :(得分:0)

您不需要使用GraphX库。 您所需要的只是collect_list 中可用的udfexplodepyspark.sql.functions函数以及一些小的python操作。

因此,您要做的第一步是收集 lists列中的所有num_group

from pyspark.sql import functions as F
y = x.select(F.collect_list("num_group").alias("collected"))

应该为dataframe提供

+----------------------------------------------------------------------------------------------------------+
|collected                                                                                                 |
+----------------------------------------------------------------------------------------------------------+
|[WrappedArray(1, 2, 3), WrappedArray(4, 2, 5, 6), WrappedArray(2, 4, 9, 10), WrappedArray(11, 12, 15, 16)]|
+----------------------------------------------------------------------------------------------------------+

下一步是定义一个udf函数来迭代所有收集的列表并检查每个列表中的元素,并根据您的要求创建一个包含合并列表的新列表。

def computation(s):
    finalList = []
    finalList.append(list(str(i) for i in s[0]))
    for index in range(1, len(s)):
        for finals in finalList:
            check = False
            for x in s[index]:
                if x in finals:
                    check = True
                    break
            if check == True:
                finals_1 = finals + list(str(i) for i in s[index])
                finalList.remove(finals)
                finalList.append(sorted(list(set(str(i) for i in finals_1))))
            else:
                finalList.append(list(str(i) for i in s[index]))
    return finalList

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType
collecting_udf = udf(computation, ArrayType(StringType()))

然后你可以explode函数将最终列表分成不同的行。

from pyspark.sql.functions import explode
y.select(explode(collecting_udf("collected")).alias("num_group"))

您应该有以下输出

+-------------------------+
|num_group                |
+-------------------------+
|[1, 10, 2, 3, 4, 5, 6, 9]|
|[11, 12, 15, 16]         |
+-------------------------+