在每个组中找到n个最小值

时间:2017-05-23 03:24:46

标签: scala apache-spark apache-spark-sql

我有一个包含不同地理位置的数据框以及与其他地方的距离。我的问题是我想为每个地理位置找到最接近的 n 位置。我的第一个想法是使用groupBy()然后进行某种聚合,但我无法使其工作。

相反,我尝试先将数据帧转换为RDD并使用groupByKey(),但它的工作方式很有效,但该方法很麻烦。有没有更好的替代方案来解决这个问题?也许使用groupBy()并以某种方式聚合?

我的方法的一个小例子n=2输入:

+---+--------+
| id|distance|
+---+--------+
|  1|     5.0|
|  1|     3.0|
|  1|     7.0|
|  1|     4.0|
|  2|     1.0|
|  2|     3.0|
|  2|     3.0|
|  2|     7.0|
+---+--------+

代码:

df.rdd.map{case Row(id: Long, distance: Double) => (id, distance)}
  .groupByKey()
  .map{case (id: Long, iter: Iterable[Double]) => (id, iter.toSeq.sorted.take(2))}
  .toDF("id", "distance")
  .withColumn("distance", explode($"distance"))

输出:

+---+--------+
| id|distance|
+---+--------+
|  1|     3.0|
|  1|     4.0|
|  2|     1.0|
|  2|     3.0|
+---+--------+

1 个答案:

答案 0 :(得分:3)

您可以使用以下窗口:

gcloud docker -- push eu.gcr.io/my-project/ubuntu-gcloud

您可以通过替换2来增加所需的结果数量。

希望这有帮助。