如何使用Spark递归聚合树状(分层)数据?

时间:2018-09-26 01:55:52

标签: scala apache-spark

给出一个表示树状分层结构的数据集,例如:

+-------+--------+
|childId|parentId|
+-------+--------+
|      1|       0|
|      2|       1|
|      3|       1|
|      4|       2|
|      5|       2|
|      6|       2|
|      7|       3|
|      8|       3|
|      9|       3|
|     10|       4|
+-------+--------+

如何使用Spark将其汇总如下?这样,对于树的每个节点,其所有子代,孙代等(直到叶子)都可以聚合:

+--------+--------------------+-----+
|parentId|            children|count|
+--------+--------------------+-----+
|       1|[15, 9, 16, 2, 17...|   16|
|       3|[15, 9, 16, 17, 7...|    7|
|       4|    [12, 13, 10, 11]|    4|
|       7|    [15, 16, 17, 14]|    4|
|       2|[12, 13, 5, 6, 10...|    7|
|       0|[15, 9, 1, 16, 2,...|   17|
+--------+--------------------+-----+

可以找到示例数据文件here

1 个答案:

答案 0 :(得分:0)

给出:

  case class Edge(childId: Int, parentId: Int)

  val edges: Dataset[Edge] = spark.read
    .option("header", value = true)
    .option("inferSchema", value = true)
    .csv("data/tree/edges.csv")
    .as[Edge]

实施类似于BFS的递归算法,如下所示:

def bfs(edges: Dataset[Edge]): Dataset[Edge] = {
    @tailrec
    def helper(n: Dataset[Edge], accum: Dataset[Edge]): Dataset[Edge] = {
      val newN = n.as("n")
        .join(edges.as("plus1"), $"n.childId" === $"plus1.parentId")
        .select($"plus1.childId", $"n.parentId")
        .as[Edge]

      if (newN.count() == 0) accum else helper(newN, accum.union(newN))
    }

    edges.cache()

    helper(edges, edges)
  }

然后,调用如下:

bfs(edges)
    .groupBy($"parentId")
    .agg(
      collect_set($"childId").alias("children"),
      countDistinct($"childId").alias("count")
    )

完成Scala实现here。不知道是否还有其他更简便,更优雅的方法来完成此任务。