在Spark组中获得前n个元素?

时间:2017-01-26 08:28:02

标签: scala apache-spark apache-spark-sql

为什么不是Spark sql top n per group

我自己认为这是重复的,只是现在。但事实并非如此。不同之处在于:我还需要事先进行聚合。我相应地编辑下面的问题。所以我需要totalScore作为所有score的总和,score用于在一个组内进行排序然后被解雇。只有每个分组排名最高的元素code才能进入列表,但totalScore应由所有score组成。所以我们不能忽略每个组的一些元素,然后再聚合。我们需要首先聚合并保留所有元素,然后摆脱一些。现在,这可以通过将原始DataFrame分成两部分来分别完成,然后加入。但这听起来效率不高。

我使用dem Schema

获得了一个Spark DataFrame
root
 |-- inputRowID: long (nullable = false)
 |-- score: double (nullable = true)
 |-- code: string (nullable = true)

我想做

val outDF = inDF.
  sort($"inputRowID", $"score".desc).
  groupBy("inputRowID").
  agg(
    sum($"score").as("totalScore"),
    collect_list($"code").as("list"))

使用架构获取outDF

root
 |-- inputRowID: long (nullable = false)
 |-- totalScore: long (nullable = false)
 |-- list: array (nullable = true)
 |    |-- element: string (containsNull = true)

现在我只想保留数组中的第一个n元素。所以我一直在尝试像

这样的东西
outDF.
  map(r => Row(r(0), r(1).take(n)) )

(当然不起作用)。或者,我想过从组中取出frist n元素,比如

val outDF = inDF.
  sort($"inputRowID", $"sorter".desc).
  groupBy("inputRowID").
  agg(take(n)).
  agg(
    collect_list($"code").as("list"))

但据我所知,没有功能。有什么想法吗?

0 个答案:

没有答案