Dataframe上的Spark-SQL窗口函数 - 查找组

时间:2016-02-10 12:52:50

标签: sql apache-spark dataframe apache-spark-sql window-functions

我有以下数据框(比如UserData)。

uid region  timestamp
a   1   1
a   1   2
a   1   3
a   1   4
a   2   5
a   2   6
a   2   7
a   3   8
a   4   9
a   4   10
a   4   11
a   4   12
a   1   13
a   1   14
a   3   15
a   3   16
a   5   17
a   5   18
a   5   19
a   5   20

此数据只是用户(uid)在不同时间(时间戳)穿越不同地区(区域)。目前,为简单起见,时间戳显示为“int”。请注意,上面的数据帧不一定按时间戳的递增顺序排列。此外,不同用户之间可能存在一些行。为简单起见,我仅以单调递增的时间戳顺序显示单个用户的数据帧。

我的目标是 - 找到用户'a'花了多少时间在每个地区以及按什么顺序?所以我的最终预期输出看起来像

uid region  regionTimeStart regionTimeEnd
a   1   1   5
a   2   5   8
a   3   8   9
a   4   9   13
a   1   13  15
a   3   15  17
a   5   17  20

根据我的发现,Spark SQL Window函数可用于此目的。 我尝试过以下的事情,

val w = Window
  .partitionBy("region")
  .partitionBy("uid")
  .orderBy("timestamp")

val resultDF = UserData.select(
  UserData("uid"), UserData("timestamp"),
  UserData("region"), rank().over(w).as("Rank"))

但是在此之后,我不确定如何获取regionTimeStartregionTimeEnd列。 regionTimeEnd列只是regionTimeStart的“引导”,除了组中的最后一个条目。

我看到聚合操作具有'第一'和'最后'功能,但为此我需要基于('uid','region')对数据进行分组,这会破坏遍历的路径的单调递增顺序,即在时间13,14用户已回到区域'1',我希望保留该区域,而不是在第1区将其保留为初始区域'1'。

如果有人可以指导我,那将非常有帮助。我是Spark的新手,与Python / JAVA Spark API相比,我对Scala Spark API有了更好的理解。

1 个答案:

答案 0 :(得分:2)

窗口函数确实很有用,尽管只有当您假设用户只访问给定区域一次时,您的方法才有效。您使用的窗口定义也是错误的 - 对partitionBy的多次调用只返回具有不同窗口定义的新对象。如果您想按多列进行分区,则应通过一次调用(.partitionBy("region", "uid"))传递它们。

让我们从标记每个地区的连续访问开始:

import org.apache.spark.sql.functions.{lag, sum, not}
import org.apache.spark.sql.expressions.Window 

val w = Window.partitionBy($"uid").orderBy($"timestamp")

val change = (not(lag($"region", 1).over(w) <=> $"region")).cast("int")
val ind = sum(change).over(w)

val dfWithInd = df.withColumn("ind", ind)

接下来,我们只是汇总群组并找到潜在客户:

import org.apache.spark.sql.functions.{lead, coalesce}

val regionTimeEnd = coalesce(lead($"timestamp", 1).over(w), $"max_")

val result = dfWithInd
  .groupBy($"uid", $"region", $"ind")
  .agg(min($"timestamp").alias("timestamp"), max($"timestamp").alias("max_"))
  .drop("ind")
  .withColumn("regionTimeEnd", regionTimeEnd)
  .withColumnRenamed("timestamp", "regionTimeStart")
  .drop("max_")

result.show

// +---+------+---------------+-------------+
// |uid|region|regionTimeStart|regionTimeEnd|
// +---+------+---------------+-------------+
// |  a|     1|              1|            5|
// |  a|     2|              5|            8|
// |  a|     3|              8|            9|
// |  a|     4|              9|           13|
// |  a|     1|             13|           15|
// |  a|     3|             15|           17|
// |  a|     5|             17|           20|
// +---+------+---------------+-------------+