PySpark - Join two dataframes on an array column (order doesn't matter)

Asked: 2019-06-22 21:03:39

Tags: apache-spark dataframe join pyspark rdd

I'm having trouble joining two dataframes in PySpark on a column that contains arrays. I want to join on these columns when the elements in the arrays are the same (order doesn't matter).

So I have a DataFrame with itemsets and their frequencies, in the following format:

+--------------------+----+
|               items|freq|
+--------------------+----+
|  [1828545, 1242385]|   4|
|  [1828545, 2032007]|   4|
|           [1137808]|  11|
|           [1209448]|   5|
|             [21002]|   5|
|           [2793224]| 209|
|     [2793224, 8590]|   7|
|[2793224, 8590, 8...|   4|
|[2793224, 8590, 8...|   4|
|[2793224, 8590, 8...|   5|
|[2793224, 8590, 1...|   4|
|  [2793224, 2593971]|  20|
+--------------------+----+

The other DataFrame holds information about users and items, in the following format:

+------------+-------------+--------------------+
|     user_id|   session_id| itemset            |
+------------+-------------+--------------------+
|WLB2T1JWGTHH|0012c5936056e|[1828545, 1242385]  |
|BZTAWYQ70C7N|00783934ea027|[2793224, 8590]     | 
|42L1RJL436ST|00c6821ed171e|[8590, 2793224]     |
|HB348HWSJAOP|00fa9607ead50|[21002]             |
|I9FOENUQL1F1|013f69b45bb58|[21002]             |  
+------------+-------------+--------------------+

Now I want to join these two dataframes on itemset and items whenever the elements of the arrays are the same (regardless of how they are ordered). The output I'd like is:

+------------+-------------+--------------------+----+
|     user_id|   session_id| itemset            |freq|
+------------+-------------+--------------------+----+
|WLB2T1JWGTHH|0012c5936056e|[1828545, 1242385]  |   4|
|BZTAWYQ70C7N|00783934ea027|[2793224, 8590]     |   7|
|42L1RJL436ST|00c6821ed171e|[8590, 2793224]     |   7|
|HB348HWSJAOP|00fa9607ead50|[21002]             |   5|
|I9FOENUQL1F1|013f69b45bb58|[21002]             |   5|
+------------+-------------+--------------------+----+

I couldn't find any solution online, only solutions for joining dataframes where the array holds a single item.
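
(For reference, those single-item approaches look roughly like the sketch below, where df_left, df_right and the scalar item column are made-up names; array_contains matches a single value against an array, so it doesn't cover my array-to-array case.)

from pyspark.sql import functions as F

# Hypothetical single-item setup: df_left has an array column `items`,
# df_right has a scalar column `item`. array_contains checks for one
# value inside an array, not whole-array equality.
joined = df_left.join(df_right, F.expr("array_contains(items, item)"))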

Thanks a lot! :)

1 answer:

Answer 0: (score: 0)

Spark's join implementation handles array columns without any problem. The only catch is that it does not ignore the order of the array elements, so the join columns have to be sorted before the join behaves correctly. You can use the sort_array function for that.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

# Recreate the two example dataframes from the question
df1 = spark.createDataFrame([
    ([1828545, 1242385],   4),
    ([1828545, 2032007],   4),
    ([1137808],           11),
    ([1209448],            5),
    ([21002],              5),
    ([2793224],          209),
    ([2793224, 8590],      7),
    ([2793224, 8590, 81],  4),
    ([2793224, 8590, 82],  4),
    ([2793224, 8590, 83],  5),
    ([2793224, 8590, 11],  4),
    ([2793224, 2593971],  20),
], ['items', 'freq'])

df2 = spark.createDataFrame([
    ('WLB2T1JWGTHH', '0012c5936056e', [1828545, 1242385]),
    ('BZTAWYQ70C7N', '00783934ea027', [2793224, 8590]),
    ('42L1RJL436ST', '00c6821ed171e', [8590, 2793224]),
    ('HB348HWSJAOP', '00fa9607ead50', [21002]),
    ('I9FOENUQL1F1', '013f69b45bb58', [21002]),
], ['user_id', 'session_id', 'itemset'])

# Sort both array columns into a canonical order so that element
# order no longer matters, then join on the sorted arrays.
df1 = df1.withColumn('items', F.sort_array('items'))
df2 = df2.withColumnRenamed('itemset', 'items').withColumn('items', F.sort_array('items'))

df1.join(df2, "items").show()

Output:

+------------------+----+------------+-------------+ 
|             items|freq|     user_id|   session_id| 
+------------------+----+------------+-------------+ 
|   [8590, 2793224]|   7|BZTAWYQ70C7N|00783934ea027| 
|   [8590, 2793224]|   7|42L1RJL436ST|00c6821ed171e| 
|[1242385, 1828545]|   4|WLB2T1JWGTHH|0012c5936056e| 
|           [21002]|   5|HB348HWSJAOP|00fa9607ead50| 
|           [21002]|   5|I9FOENUQL1F1|013f69b45bb58| 
+------------------+----+------------+-------------+
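
If you also want to keep each user's original, unsorted itemset in the output, as in the desired table in the question, a small variant is to sort into a separate join key instead of overwriting the column. This is just a sketch, assuming df1 and df2 as originally created above (i.e. before the rename and sort step):

from pyspark.sql import functions as F

# Put the sorted arrays into a helper column `key`, so the original
# `itemset` column stays intact in the output.
df1_keyed = df1.withColumn('key', F.sort_array('items'))
df2_keyed = df2.withColumn('key', F.sort_array('itemset'))

(df2_keyed.join(df1_keyed, 'key')
          .select('user_id', 'session_id', 'itemset', 'freq')
          .show(truncate=False))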