根据PySpark中另一个数据框的列值更新列中的值

时间:2018-10-22 17:03:02

标签: pyspark intersection collect

我在PySpark中有两个数据帧:df1

+---+-----------------+
|id1|           items1|
+---+-----------------+
|  0|     [B, C, D, E]|
|  1|        [E, A, C]|
|  2|     [F, A, E, B]|
|  3|        [E, G, A]|
|  4|  [A, C, E, B, D]|
+---+-----------------+ 

df2

+---+-----------------+
|id2|           items2|
+---+-----------------+
|001|              [B]|
|002|              [A]|
|003|              [C]|
|004|              [E]|
+---+-----------------+ 

我想在df1中创建一个新列,以更新其中的值 items1列,以便它只保留在items2df2中出现的值。结果应如下所示:

+---+-----------------+----------------------+
|id1|           items1|        items1_updated|
+---+-----------------+----------------------+
|  0|     [B, C, D, E]|             [B, C, E]|
|  1|        [E, A, C]|             [E, A, C]|
|  2|     [F, A, E, B]|             [A, E, B]|
|  3|        [E, G, A]|                [E, A]|
|  4|  [A, C, E, B, D]|          [A, C, E, B]|
+---+-----------------+----------------------+

我通常会使用collect()来获取items2列中所有值的列表,然后使用应用于items1中每一行的udf来获取交集。但是数据非常大(超过一千万行),因此我无法使用collect()来获取此类列表。有没有一种方法可以同时将数据保持为数据帧格式?还是不使用collect()的其他方式?

1 个答案:

答案 0 :(得分:1)

您要做的第一件事是explode df2.items2中的值,以便将数组的内容放在单独的行上:

from pyspark.sql.functions import explode
df2 = df2.select(explode("items2").alias("items2"))
df2.show()
#+------+
#|items2|
#+------+
#|     B|
#|     A|
#|     C|
#|     E|
#+------+

(这假设df2.items2中的值是不同的-如果不是,则需要添加df2 = df2.distinct()。)

选项1 :使用crossJoin

现在,您可以将新的crossJoin df2返回到df1,并仅保留df1.items1包含df2.items2中元素的行。我们可以使用pyspark.sql.functions.array_containsthis trick来实现这一点,而我们可以使用use a column value as a parameter

过滤后,按id1items1分组并使用pyspark.sql.functions.collect_list

进行汇总
from pyspark.sql.functions import expr, collect_list

df1.alias("l").crossJoin(df2.alias("r"))\
    .where(expr("array_contains(l.items1, r.items2)"))\
    .groupBy("l.id1", "l.items1")\
    .agg(collect_list("r.items2").alias("items1_updated"))\
    .show()
#+---+---------------+--------------+
#|id1|         items1|items1_updated|
#+---+---------------+--------------+
#|  1|      [E, A, C]|     [A, C, E]|
#|  0|   [B, C, D, E]|     [B, C, E]|
#|  4|[A, C, E, B, D]|  [B, A, C, E]|
#|  3|      [E, G, A]|        [A, E]|
#|  2|   [F, A, E, B]|     [B, A, E]|
#+---+---------------+--------------+

选项2 :爆炸df1.items1并退出联接:

另一个选择是在explodeitems1的内容df1并进行左连接。加入后,我们必须进行类似的分组和聚合。之所以有效,是因为collect_list会忽略不匹配行引入的null

df1.withColumn("items1", explode("items1")).alias("l")\
    .join(df2.alias("r"), on=expr("l.items1=r.items2"), how="left")\
    .groupBy("l.id1")\
    .agg(
        collect_list("l.items1").alias("items1"),
        collect_list("r.items2").alias("items1_updated")
    ).show()
#+---+---------------+--------------+
#|id1|         items1|items1_updated|
#+---+---------------+--------------+
#|  0|   [E, B, D, C]|     [E, B, C]|
#|  1|      [E, C, A]|     [E, C, A]|
#|  3|      [E, A, G]|        [E, A]|
#|  2|   [F, E, B, A]|     [E, B, A]|
#|  4|[E, B, D, C, A]|  [E, B, C, A]|
#+---+---------------+--------------+