我在PySpark中有两个数据帧:df1
+---+-----------------+
|id1| items1|
+---+-----------------+
| 0| [B, C, D, E]|
| 1| [E, A, C]|
| 2| [F, A, E, B]|
| 3| [E, G, A]|
| 4| [A, C, E, B, D]|
+---+-----------------+
和df2
:
+---+-----------------+
|id2| items2|
+---+-----------------+
|001| [B]|
|002| [A]|
|003| [C]|
|004| [E]|
+---+-----------------+
我想在df1
中创建一个新列,以更新其中的值
items1
列,以便它只保留在items2
中df2
中出现的值。结果应如下所示:
+---+-----------------+----------------------+
|id1| items1| items1_updated|
+---+-----------------+----------------------+
| 0| [B, C, D, E]| [B, C, E]|
| 1| [E, A, C]| [E, A, C]|
| 2| [F, A, E, B]| [A, E, B]|
| 3| [E, G, A]| [E, A]|
| 4| [A, C, E, B, D]| [A, C, E, B]|
+---+-----------------+----------------------+
我通常会使用collect()来获取items2
列中所有值的列表,然后使用应用于items1
中每一行的udf来获取交集。但是数据非常大(超过一千万行),因此我无法使用collect()来获取此类列表。有没有一种方法可以同时将数据保持为数据帧格式?还是不使用collect()的其他方式?
答案 0 :(得分:1)
您要做的第一件事是explode
df2.items2
中的值,以便将数组的内容放在单独的行上:
from pyspark.sql.functions import explode
df2 = df2.select(explode("items2").alias("items2"))
df2.show()
#+------+
#|items2|
#+------+
#| B|
#| A|
#| C|
#| E|
#+------+
(这假设df2.items2
中的值是不同的-如果不是,则需要添加df2 = df2.distinct()
。)
选项1 :使用crossJoin
:
现在,您可以将新的crossJoin
df2
返回到df1
,并仅保留df1.items1
包含df2.items2
中元素的行。我们可以使用pyspark.sql.functions.array_contains
和this trick来实现这一点,而我们可以使用use a column value as a parameter。
过滤后,按id1
和items1
分组并使用pyspark.sql.functions.collect_list
from pyspark.sql.functions import expr, collect_list
df1.alias("l").crossJoin(df2.alias("r"))\
.where(expr("array_contains(l.items1, r.items2)"))\
.groupBy("l.id1", "l.items1")\
.agg(collect_list("r.items2").alias("items1_updated"))\
.show()
#+---+---------------+--------------+
#|id1| items1|items1_updated|
#+---+---------------+--------------+
#| 1| [E, A, C]| [A, C, E]|
#| 0| [B, C, D, E]| [B, C, E]|
#| 4|[A, C, E, B, D]| [B, A, C, E]|
#| 3| [E, G, A]| [A, E]|
#| 2| [F, A, E, B]| [B, A, E]|
#+---+---------------+--------------+
选项2 :爆炸df1.items1
并退出联接:
另一个选择是在explode
中items1
的内容df1
并进行左连接。加入后,我们必须进行类似的分组和聚合。之所以有效,是因为collect_list
会忽略不匹配行引入的null
值
df1.withColumn("items1", explode("items1")).alias("l")\
.join(df2.alias("r"), on=expr("l.items1=r.items2"), how="left")\
.groupBy("l.id1")\
.agg(
collect_list("l.items1").alias("items1"),
collect_list("r.items2").alias("items1_updated")
).show()
#+---+---------------+--------------+
#|id1| items1|items1_updated|
#+---+---------------+--------------+
#| 0| [E, B, D, C]| [E, B, C]|
#| 1| [E, C, A]| [E, C, A]|
#| 3| [E, A, G]| [E, A]|
#| 2| [F, E, B, A]| [E, B, A]|
#| 4|[E, B, D, C, A]| [E, B, C, A]|
#+---+---------------+--------------+