pyspark intersection of two dataframe columns of list type

Date: 2018-08-28 16:50:34

Tags: list dataframe join pyspark intersection

I have a dataframe with location columns where each cell of one column contains a list of country names. I want to find the country_name common to these two columns and add it to an output dataframe, coding this in pyspark.

df_input = spark.createDataFrame([
    (100001, 12301, 'India',   ['India', 'USA', 'Germany']),
    (100002, 12302, 'Germany', ['India', 'UK', 'Germany']),
    (100003, 12303, 'Taiwan',  ['India', 'Japan', 'China'])
], ("pos_id", "emp_id", "e_location", "p_location"))

Input dataframe:

+------+------+----------+--------------------+
|pos_id|emp_id|e_location|          p_location|
+------+------+----------+--------------------+
|100001| 12301|     India|[India, USA, Germ...|
|100002| 12302|   Germany|[India, UK, Germany]|
|100003| 12303|    Taiwan|[India, Japan, Ch...|
+------+------+----------+--------------------+

Now I want the intersection between the two columns, as shown in the output DF below.

Output dataframe:

+------+---------+----------------+
|emp_id|   pos_id| matched_country|
+------+---------+----------------+
| 12301|   100001|           India|
| 12302|   100002|         Germany|
| 12303|   100003|            None|
+------+---------+----------------+

1 answer:

Answer 0: (score: 0)

I assume your df_ploc dataframe contains the list of countries in the p_location column. You can then use something like the following to create the intersection while keeping all combinations of pos_id and emp_id.

Because of the missing brackets, and assuming lists are used, I modified the initial snippet (otherwise you would have to use the split method; a sketch of that variant follows below).
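As an aside, if p_location arrived as a comma-separated string instead of a list, a minimal sketch of that split variant could look like this (the string values here are assumed for illustration):

from pyspark.sql.functions import split

# hypothetical: convert a comma-separated string column into an array column
df_ploc_str = spark.createDataFrame([
    (1, 'India,USA,Germany')
], ("pos_id", "p_location"))
df_ploc = df_ploc_str.withColumn("p_location", split("p_location", ","))

The modified snippet with list literals: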

df_ploc = spark.createDataFrame([
    (1, ['India', 'USA', 'Germany']),
    (2, ['India', 'UK', 'Germany']),
    (3, ['India', 'Japan', 'China'])
], ("pos_id", "p_location"))

df_eloc = spark.createDataFrame([
    (12301, 'India'), (12302, 'Germany'), (12303, 'Taiwan')
], ("emp_id", "e_location"))

# explode flattens the list column; min is needed for the aggregations further below
from pyspark.sql.functions import explode, min

# create a new dataframe with one line per country
df_new = df_ploc.select("pos_id", explode("p_location").alias("new_location"))
df_eloc.join(df_new, df_new["new_location"] == df_eloc["e_location"], how="inner").show()

The output looks like this:

+------+----------+------+------------+
|emp_id|e_location|pos_id|new_location|
+------+----------+------+------------+
| 12302|   Germany|     1|     Germany|
| 12302|   Germany|     2|     Germany|
| 12301|     India|     3|       India|
| 12301|     India|     1|       India|
| 12301|     India|     2|       India|
+------+----------+------+------------+

The modified join then looks like:

df_eloc.join(df_new, df_new["new_location"] == df_eloc["e_location"], how="left").groupBy("emp_id","new_location").agg(min("pos_id")).show()

The output is similar to:

+------+------------+-----------+
|emp_id|new_location|min(pos_id)|
+------+------------+-----------+
| 12301|       India|          1|
| 12302|     Germany|          1|
| 12303|        null|       null|
+------+------------+-----------+

If your pos_id should just be an enumeration (e.g. 1, 2, 3, 4, 5, ...), you can use a function such as row_number to create this column. For example:

from pyspark.sql.functions import monotonically_increasing_id
from pyspark.sql.functions import row_number
from pyspark.sql.window import Window as W

# left join keeps unmatched employees (e.g. Taiwan), then reduce to one row per emp_id
df1 = df_eloc.join(df_new, df_new["new_location"] == df_eloc["e_location"], how="left").groupBy("emp_id", "new_location").agg(min("pos_id"))
# add an increasing id and turn it into a consecutive enumeration via row_number
df1 = df1.withColumn("idx", monotonically_increasing_id())
windowSpec = W.orderBy("idx")
df1.withColumn("pos_id", row_number().over(windowSpec)).select("emp_id", "pos_id", "new_location").show()

Output:

+------+------+------------+
|emp_id|pos_id|new_location|
+------+------+------------+
| 12301|     1|       India|
| 12302|     2|     Germany|
| 12303|     3|        null|
+------+------+------------+
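As a side note, because e_location and the p_location list sit in the same row of the original df_input, the matched country can also be derived without a join. A minimal sketch, assuming Spark 2.x where array_contains is available as a SQL expression:

from pyspark.sql.functions import col, expr, when

# keep e_location only when it appears in the row's own p_location array, otherwise null
df_input.select(
    "emp_id",
    "pos_id",
    when(expr("array_contains(p_location, e_location)"), col("e_location")).alias("matched_country")
).show()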