使用Spark DataFrame获取分组后所有组的TopN

时间:2015-11-11 16:44:30

标签: sql scala apache-spark apache-spark-sql

我有一个Spark SQL DataFrame:

user1 item1 rating1
user1 item2 rating2
user1 item3 rating3
user2 item1 rating4
...

如何按用户分组,然后使用Scala从每个组返回TopN个项目?

使用Python的相似代码:

df.groupby("user").apply(the_func_get_TopN)

1 个答案:

答案 0 :(得分:19)

您可以使用rank窗口功能,如下所示

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{rank, desc}

val n: Int = ???

// Window definition
val w = Window.partitionBy($"user").orderBy(desc("rating"))

// Filter
df.withColumn("rank", rank.over(w)).where($"rank" <= n)

如果您不关心关系,那么您可以将rank替换为row_number