如何堆叠两列进行分组?

时间:2018-06-22 16:37:09

标签: scala apache-spark apache-spark-sql

我有以下DataFrame df,它表示具有节点A,B,C和D的图。每个节点都属于组1或2:

src   dst   group_src   group_dst
A     B     1           1
A     B     1           1
B     A     1           1
A     C     1           2
C     D     2           2
D     C     2           2

我需要计算不同的节点数和每组的边数。结果应为以下内容:

group   nodes_count    edges_count
1       2              3
2       2              2

不考虑边缘A-> C,因为节点属于不同的组。

我不知道如何堆叠列group_srcgroup_dst以便按唯一列group进行分组。而且我也不知道如何计算组内的边数。

df
  .groupBy("group_src","group_dst")
  .agg(countDistinct("srcId","dstId").as("nodes_count"))

1 个答案:

答案 0 :(得分:2)

我认为可能有必要使用两个步骤:

val edges = df.filter($"group_src" === $"group_dst")
  .groupBy($"group_src".as("group"))
  .agg(count("*").as("edges_count"))

val nodes = df.select($"src".as("id"), $"group_src".as("group"))
  .union(df.select($"dst".as("id"), $"group_dst".as("group"))
  .groupBy("group").agg(countDistinct($"id").as("nodes_count"))

nodes.join(edges, "group")

选择特定的列后,您可以使用.union()完成列的“堆叠”。