由于数据量很大,我必须使用pyspark
将数据帧值(列表)组合在不同的行中。
这样的数据框:
x = sqlContext.createDataFrame([("A", ['1','2','3']),("B", ['4','2','5','6']),("C", ['2','4','9','10']),("D", ['11','12','15','16'])],["index", "num_group"])
+-----+----------------+
|index| num_group|
+-----+----------------+
| A| [1, 2, 3]|
| B| [4, 2, 5, 6]|
| C| [2, 4, 9, 10]|
| D|[11, 12, 15, 16]|
+-----+----------------+
我希望按列表合并num_group
,它们具有相同的元素:
(索引是无意义的值或字符串)
+-------------------------+
| num_group|
+-------------------------+
|[1, 2, 3, 4, 5, 6, 9, 10]|
| [11, 12, 15, 16]|
+-------------------------+
我想我可以使用graphframes GraphX来查找连接,并根据不同行中的相等值合并两行或更多行。
有可能吗?我并不是真正理解documents的例子。
非常感谢任何帮助。
答案 0 :(得分:0)
您不需要使用GraphX库。 您所需要的只是collect_list
中可用的udf
,explode
和pyspark.sql.functions
函数以及一些小的python操作。
因此,您要做的第一步是收集 lists
列中的所有num_group
。
from pyspark.sql import functions as F
y = x.select(F.collect_list("num_group").alias("collected"))
应该为dataframe
提供
+----------------------------------------------------------------------------------------------------------+
|collected |
+----------------------------------------------------------------------------------------------------------+
|[WrappedArray(1, 2, 3), WrappedArray(4, 2, 5, 6), WrappedArray(2, 4, 9, 10), WrappedArray(11, 12, 15, 16)]|
+----------------------------------------------------------------------------------------------------------+
下一步是定义一个udf
函数来迭代所有收集的列表并检查每个列表中的元素,并根据您的要求创建一个包含合并列表的新列表。
def computation(s):
finalList = []
finalList.append(list(str(i) for i in s[0]))
for index in range(1, len(s)):
for finals in finalList:
check = False
for x in s[index]:
if x in finals:
check = True
break
if check == True:
finals_1 = finals + list(str(i) for i in s[index])
finalList.remove(finals)
finalList.append(sorted(list(set(str(i) for i in finals_1))))
else:
finalList.append(list(str(i) for i in s[index]))
return finalList
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType
collecting_udf = udf(computation, ArrayType(StringType()))
然后你可以explode
函数将最终列表分成不同的行。
from pyspark.sql.functions import explode
y.select(explode(collecting_udf("collected")).alias("num_group"))
您应该有以下输出
+-------------------------+
|num_group |
+-------------------------+
|[1, 10, 2, 3, 4, 5, 6, 9]|
|[11, 12, 15, 16] |
+-------------------------+