从Spark中的每个分区获取N个值

时间:2016-07-27 14:11:50

标签: scala apache-spark

假设我有以下数据:

val DataSort = Seq(("a",5),("b",13),("b",2),("b",1),("c",4),("a",1),("b",15),("c",3),("c",1))
val DataSortRDD = sc.parallelize(DataSort,2)

现在有两个分区:

scala>DataSortRDD.glom().take(2).head
res53: Array[(String,Int)] = Array(("a",5),("b",13),("b",2),("b",1),("c",4))
scala>DataSortRDD.glom().take(2).tail
res54: Array[(String,Int)] = Array(Array(("a",1),("b",15),("c",3),("c",2),("c",1)))

假设在每个分区中数据已经使用类似sortWithinPartitions(col("src").desc,col("rank").desc)的数据进行排序(数据帧的数据,但仅用于说明)。

我想要的是从每个分区得到每个字母的前两个值(如果有超过2个值)。因此,在此示例中,每个分区的结果应为:

scala>HypotheticalRDD.glom().take(2).head
Array(("a",5),("b",13),("b",2),("c",4))
scala>HypotheticalRDD.glom().take(2).tail
Array(Array(("a",1),("b",15),("c",3),("c",2)))

我知道我必须使用mapPartition函数,但我不清楚如何迭代每个分区中的值并得到第一个2.任何提示?

修改:更确切地说。我知道在每个分区中,数据先是按'字母'排序,然后按'计数'排序。所以我的主要想法是mapPartition中的输入函数应遍历分区并yield每个字母的前两个值。这可以通过检查每个迭代.next()值来完成。这就是我在python中编写它的方法:

def limit_on_sorted(iterator):
    oldKey = None
    cnt = 0
    while True:
        elem = iterator.next()
        if not elem:
            return
        curKey = elem[0]
        if curKey == oldKey:
            cnt +=1
            if cnt >= 2:
                yield None
        else:
            oldKey = curKey
            cnt = 0
        yield elem

DataSortRDDpython.mapPartitions(limit_on_sorted,preservesPartitioning=True).filter(lambda x:x!=None)

1 个答案:

答案 0 :(得分:1)

假设您并不真正关心结果的分区,您可以使用 mapPartitionsWithIndex 来合并分区在您 groupBy 的密钥中输入ID,然后您可以轻松地为每个密钥获取前两项:

val result: RDD[(String, Int)] = DataSortRDD
  .mapPartitionsWithIndex {
     // add the partition ID into the "key" of every record:
     case (partitionId, itr) => itr.map { case (k, v) => ((k, partitionId), v) }
   }
  .groupByKey() // groups by letter and partition id
  // take only first two records, and drop partition id
  .flatMap { case ((k, _), itr) => itr.take(2).toArray.map((k, _)) }

println(result.collect().toList)
// prints:
// List((a,5), (b,15), (b,13), (b,2), (a,1), (c,4), (c,3))

请注意,最终结果没有以相同的方式进行分区(groupByKey更改分区),我假设这对您尝试的内容并不重要做(坦率地说,逃脱了我)。

编辑:如果您想避免随机播放并在每个分区中执行所有操作:

val result: RDD[(String, Int)] = DataSortRDD
  .mapPartitions(_.toList.groupBy(_._1).mapValues(_.take(2)).values.flatten.iterator, true)