Spark-仅分组和汇总几个最小的项目

时间:2019-06-27 11:43:07

标签: scala apache-spark

简而言之

我有两个数据框的笛卡尔乘积(交叉联接)和函数,该函数为该乘积的给定元素给出一些分数。我现在想为第一个DF的每个成员获取几个第二个DF的“最佳匹配”元素。

详细信息

下面是一个简化的示例,因为我的真实代码在某些地方增加了额外的字段和过滤器。

给出两组数据,每组数据都有一些id和值:

// simple rdds of tuples
val rdd1 = sc.parallelize(Seq(("a", 31),("b", 41),("c", 59),("d", 26),("e",53),("f",58)))
val rdd2 = sc.parallelize(Seq(("z", 16),("y", 18),("x",3),("w",39),("v",98), ("u", 88)))

// convert them to dataframes:
val df1 = spark.createDataFrame(rdd1).toDF("id1", "val1")
val df2 = spark.createDataFrame(rdd2).toDF("id2", "val2")

以及一些函数,它们针对第一和第二数据集中的一对元素给出其“匹配分数”:

def f(a:Int, b:Int):Int = (a * a + b * b * b) % 17
// convert it to udf
val fu = udf((a:Int, b:Int) => f(a, b))

我们可以创建两组的乘积并计算每对的得分:

val dfc = df1.crossJoin(df2)
val r = dfc.withColumn("rez", fu(col("val1"), col("val2")))
r.show

+---+----+---+----+---+
|id1|val1|id2|val2|rez|
+---+----+---+----+---+
|  a|  31|  z|  16|  8|
|  a|  31|  y|  18| 10|
|  a|  31|  x|   3|  2|
|  a|  31|  w|  39| 15|
|  a|  31|  v|  98| 13|
|  a|  31|  u|  88|  2|
|  b|  41|  z|  16| 14|
|  c|  59|  z|  16| 12|
...

现在我们要将此结果按id1分组:

r.groupBy("id1").agg(collect_set(struct("id2", "rez")).as("matches")).show

+---+--------------------+
|id1|             matches|
+---+--------------------+
|  f|[[v,2], [u,8], [y...|
|  e|[[y,5], [z,3], [x...|
|  d|[[w,2], [x,6], [v...|
|  c|[[w,2], [x,6], [v...|
|  b|[[v,2], [u,8], [y...|
|  a|[[x,2], [y,10], [...|
+---+--------------------+

但是实际上,我们只希望保留(匹配3个)极少的“匹配”,即得分最高(例如得分最低)的匹配。

问题是

  1. 如何将“匹配项”排序并减少到前N个元素?可能是关于collect_list和sort_array的事情,尽管我不知道如何按内部字段排序。

  2. 在输入DF较大的情况下是否有办法确保优化-例如汇总时直接选择最小值。我知道,如果我编写的代码没有任何火花,可以轻松完成-为每个id1保留较小的数组或优先级队列,并在应该添加的位置添加元素,可能会删除一些先前添加的元素。

例如可以肯定的是,交叉联接是一项代价高昂的操作,但是我要避免在结果上浪费内存,而在下一步中将要删除的大部分结果。我的实际用例处理的条目少于100万个DF,因此交叉联接仍然可行,但是由于我们只想为每个id1选择10-20个顶级匹配,因此似乎很希望不要保留不必要的数据在步骤之间。

1 个答案:

答案 0 :(得分:1)

首先,我们只需要前n行。为此,我们通过“ id1”对DF进行分区,并按res对组进行排序。我们使用它向DF添加行号列,就像我们可以使用 where 函数获取前n行。比您可以继续执行您编写的相同代码。按“ id1”分组并收集列表。直到现在,您已经拥有最高的行数。

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val n = 3
val w = Window.partitionBy($"id1").orderBy($"res".desc)
val res = r.withColumn("rn", row_number.over(w)).where($"rn" <= n).groupBy("id1").agg(collect_set(struct("id2", "res")).as("matches"))

第二个选项可能更好,因为您无需将DF分组两次:

val sortTakeUDF = udf{(xs: Seq[Row], n: Int)} => xs.sortBy(_.getAs[Int]("res")).reverse.take(n).map{case Row(x: String, y:Int)}}
r.groupBy("id1").agg(sortTakeUDF(collect_set(struct("id2", "res")), lit(n)).as("matches"))

在这里,我们创建一个udf,它使用数组列和一个整数值n。 udf按“ res”对数组进行排序,仅返回前n个元素。