Spark合并公共元素集合

时间:2017-07-24 13:53:59

标签: scala apache-spark

我有一个如下所示的DataFrame:

+-----------+-----------+
|  Package  | Addresses |
+-----------+-----------+
| Package 1 | address1  |
| Package 1 | address2  |
| Package 1 | address3  |
| Package 2 | address3  |
| Package 2 | address4  |
| Package 2 | address5  |
| Package 2 | address6  |
| Package 3 | address7  |
| Package 3 | address8  |
| Package 4 | address9  |
| Package 5 | address9  |
| Package 5 | address1  |
| Package 6 | address10 |
| Package 7 | address8  |
+-----------+-----------+

我需要找到在不同包中一起看到的所有地址。示例输出:

+----+------------------------------------------------------------------------+
| Id |                               Addresses                                |
+----+------------------------------------------------------------------------+
|  1 | [address1, address2, address3, address4, address5, address6, address9] |
|  2 | [address7, address8]                                                   |
|  3 | [address10]                                                            |
+----+------------------------------------------------------------------------+

所以,我有DataFrame。我按package(而不是分组)对其进行分组:

val rdd = packages.select($"package", $"address").
  map{
    x => {
      (x(0).toString(), x(1).toString())
    }
  }.rdd.combineByKey(
  (source) => {
    Set[String](source)
  },

  (acc: Set[String], v) => {
    acc + v
  },

  (acc1: Set[String], acc2: Set[String]) => {
    acc1 ++ acc2
  }
)

然后,我合并了具有公共地址的行:

val result = rdd.treeAggregate(
  Set.empty[Set[String]]
)(
  (map: Set[Set[String]], row) => {
    val vals = row._2
    val sets = map + vals

    // copy-paste from here https://stackoverflow.com/a/25623014/772249
    sets.foldLeft(Set.empty[Set[String]])((cum, cur) => {
      val (hasCommon, rest) = cum.partition(_ & cur nonEmpty)
      rest + (cur ++ hasCommon.flatten)
    })
  },
  (map1, map2) => {
    val sets = map1 ++ map2

    // copy-paste from here https://stackoverflow.com/a/25623014/772249
    sets.foldLeft(Set.empty[Set[String]])((cum, cur) => {
      val (hasCommon, rest) = cum.partition(_ & cur nonEmpty)
      rest + (cur ++ hasCommon.flatten)
    })
  },
  10
)

但是,无论我做什么,treeAggregate花了很长时间,我无法完成单项任务。原始数据大小约为250gb。我尝试过不同的群集,但treeAggregate花了太长时间。

treeAggregate之前的所有内容都很有效,但之后却很有效。

我尝试了不同的spark.sql.shuffle.partitions(默认值,2000,10000),但它似乎并不重要。

我为depth尝试了不同的treeAggregate,但没有注意到差异。

相关问题:

  1. Merge Sets of Sets that contain common elements in Scala
  2. Spark complex grouping

1 个答案:

答案 0 :(得分:3)

查看您的数据,就好像它是一个地址是顶点的图形,如果有两个包的话,它们就有连接。然后问题的解决方案将是图表的connected components

Sparks gpraphX库具有优化功能,可以找到connected components。它将返回不同连接组件中的顶点,将它们视为每个连接组件的ID。

然后,如果需要,您可以收集连接到它的所有其他地址。

查看this article他们如何使用图表来实现与您相同的分组。