简而言之
我有两个数据框的笛卡尔乘积(交叉联接)和函数,该函数为该乘积的给定元素给出一些分数。我现在想为第一个DF的每个成员获取几个第二个DF的“最佳匹配”元素。
详细信息
下面是一个简化的示例,因为我的真实代码在某些地方增加了额外的字段和过滤器。
给出两组数据,每组数据都有一些id和值:
// simple rdds of tuples
val rdd1 = sc.parallelize(Seq(("a", 31),("b", 41),("c", 59),("d", 26),("e",53),("f",58)))
val rdd2 = sc.parallelize(Seq(("z", 16),("y", 18),("x",3),("w",39),("v",98), ("u", 88)))
// convert them to dataframes:
val df1 = spark.createDataFrame(rdd1).toDF("id1", "val1")
val df2 = spark.createDataFrame(rdd2).toDF("id2", "val2")
以及一些函数,它们针对第一和第二数据集中的一对元素给出其“匹配分数”:
def f(a:Int, b:Int):Int = (a * a + b * b * b) % 17
// convert it to udf
val fu = udf((a:Int, b:Int) => f(a, b))
我们可以创建两组的乘积并计算每对的得分:
val dfc = df1.crossJoin(df2)
val r = dfc.withColumn("rez", fu(col("val1"), col("val2")))
r.show
+---+----+---+----+---+
|id1|val1|id2|val2|rez|
+---+----+---+----+---+
| a| 31| z| 16| 8|
| a| 31| y| 18| 10|
| a| 31| x| 3| 2|
| a| 31| w| 39| 15|
| a| 31| v| 98| 13|
| a| 31| u| 88| 2|
| b| 41| z| 16| 14|
| c| 59| z| 16| 12|
...
现在我们要将此结果按id1
分组:
r.groupBy("id1").agg(collect_set(struct("id2", "rez")).as("matches")).show
+---+--------------------+
|id1| matches|
+---+--------------------+
| f|[[v,2], [u,8], [y...|
| e|[[y,5], [z,3], [x...|
| d|[[w,2], [x,6], [v...|
| c|[[w,2], [x,6], [v...|
| b|[[v,2], [u,8], [y...|
| a|[[x,2], [y,10], [...|
+---+--------------------+
但是实际上,我们只希望保留(匹配3个)极少的“匹配”,即得分最高(例如得分最低)的匹配。
问题是
如何将“匹配项”排序并减少到前N个元素?可能是关于collect_list和sort_array的事情,尽管我不知道如何按内部字段排序。
在输入DF较大的情况下是否有办法确保优化-例如汇总时直接选择最小值。我知道,如果我编写的代码没有任何火花,可以轻松完成-为每个id1
保留较小的数组或优先级队列,并在应该添加的位置添加元素,可能会删除一些先前添加的元素。
例如可以肯定的是,交叉联接是一项代价高昂的操作,但是我要避免在结果上浪费内存,而在下一步中将要删除的大部分结果。我的实际用例处理的条目少于100万个DF,因此交叉联接仍然可行,但是由于我们只想为每个id1
选择10-20个顶级匹配,因此似乎很希望不要保留不必要的数据在步骤之间。
答案 0 :(得分:1)
首先,我们只需要前n行。为此,我们通过“ id1”对DF进行分区,并按res对组进行排序。我们使用它向DF添加行号列,就像我们可以使用 where 函数获取前n行。比您可以继续执行您编写的相同代码。按“ id1”分组并收集列表。直到现在,您已经拥有最高的行数。
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
val n = 3
val w = Window.partitionBy($"id1").orderBy($"res".desc)
val res = r.withColumn("rn", row_number.over(w)).where($"rn" <= n).groupBy("id1").agg(collect_set(struct("id2", "res")).as("matches"))
第二个选项可能更好,因为您无需将DF分组两次:
val sortTakeUDF = udf{(xs: Seq[Row], n: Int)} => xs.sortBy(_.getAs[Int]("res")).reverse.take(n).map{case Row(x: String, y:Int)}}
r.groupBy("id1").agg(sortTakeUDF(collect_set(struct("id2", "res")), lit(n)).as("matches"))
在这里,我们创建一个udf,它使用数组列和一个整数值n。 udf按“ res”对数组进行排序,仅返回前n个元素。