我有2个数据帧,每个数据帧都有一个Array [String]作为其中一列。对于一个数据帧中的每个条目,我需要在另一个数据帧中找出子集(如果有的话)。这里有一个例子:
DF1:
----------------------------------------------------
id : Long | labels : Array[String]
---------------------------------------------------
10 | [label1, label2, label3]
11 | [label4, label5]
12 | [label6, label7]
DF2:
----------------------------------------------------
item : String | labels : Array[String]
---------------------------------------------------
item1 | [label1, label2, label3, label4, label5]
item2 | [label4, label5]
item3 | [label4, label5, label6, label7]
在我描述的子集操作之后,预期的o / p应该是
DF3:
----------------------------------------------------
item : String | id : Long
---------------------------------------------------
item1 | [10, 11]
item2 | [11]
item3 | [11, 12]
保证DF2在DF1中始终具有相应的子集,因此不会有剩余的元素。
有人可以帮忙在这里采取正确的方法吗?看起来对于DF2中的每个元素,我需要扫描DF1并在第二列上进行子集操作(或设置减法),直到找到所有子集并耗尽该行中的标签,同时这样做会累积“id”列表“田野。我如何以紧凑和有效的方式做到这一点?任何帮助是极大的赞赏。实际上,我可能在DF1中有100个元素,在DF2中有1000个元素。
答案 0 :(得分:0)
我不知道有任何方法可以有效地执行此类操作。但是,这里有一个使用UDF
和笛卡尔联接的可能解决方案。
UDF
接受两个序列并检查第一个中的所有字符串是否都存在于第二个序列中:
val matchLabel = udf((array1: Seq[String], array2: Seq[String]) => {
array1.forall{x => array2.contains(x)}
})
要使用笛卡尔连接,需要启用它,因为它的计算成本很高。
val spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.crossJoin.enabled", true)
使用UDF
将两个数据帧连接在一起。然后,生成的数据框按item
列分组,以收集所有ID的列表。使用与问题相同的DF1
和DF2
:
val DF3 = DF2.join(DF1, matchLabel(DF1("labels"), DF2("labels")))
.groupBy("item")
.agg(collect_list("id").as("id"))
结果如下:
+-----+--------+
| item| id|
+-----+--------+
|item3|[11, 12]|
|item2| [11]|
|item1|[10, 11]|
+-----+--------+