rowNumber()over(partition_index)使用spark-shell中的mapPartitionsWithIndex

时间:2017-05-09 17:44:29

标签: scala apache-spark

我试图在分区中添加分区索引和rownumber到rdd,我做到了。但是当我试图获得最后一个rownumber的值时,我得到零,rownumber数组似乎没有被触及。变量范围问题?

它就像rowNumber()/ count()over(partition_index),但rownumber在一个循环中与分区索引一起添加,所以可能效率更高?

代码如下:

scala> val rdd1 = sc.makeRDD(100 to 110)
rdd1: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[32] at makeRDD at <console>:25

scala> val rownums=new Array[Int](3)
rownums: Array[Int] = Array(0, 0, 0)

scala> val rdd2=rdd1.repartition(3).mapPartitionsWithIndex( (idx, itr) => itr.map(r => (idx, {rownums(idx)+=1;rownums(idx)}, r)) )
rdd2: org.apache.spark.rdd.RDD[(Int, Int, Int)] = MapPartitionsRDD[37] at mapPartitionsWithIndex at <console>:29

scala> rdd2.collect.foreach(println)
(0,1,100)
(0,2,107)
(0,3,104)
(0,4,105)
(0,5,106)
(0,6,110)
(1,1,102)
(1,2,108)
(1,3,103)
(2,1,101)
(2,2,109)

scala> //uneffected??

scala> rownums.foreach(println)
0
0
0

scala> rownums
res20: Array[Int] = Array(0, 0, 0)

我期待(6,3,2)rownums :(

使用累加器解决:

scala> import org.apache.spark.util._
import org.apache.spark.util._

scala> val rownums=new Array[LongAccumulator](3)
rownums: Array[org.apache.spark.util.LongAccumulator] = Array(null, null, null)

scala> for(i <- 0 until rownums.length){rownums(i)=sc.longAccumulator("rownum_"+i)}

scala> val rdd1 = sc.makeRDD(100 to 110)
rdd1: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[92] at makeRDD at <console>:124

scala> val rownums2=new Array[Int](3)
rownums2: Array[Int] = Array(0, 0, 0)

scala> val rdd2=rdd1.repartition(3).mapPartitionsWithIndex( (idx, itr) => itr.map(r => (idx, {rownums2(idx)+=1;rownums(idx).add(1);rownums2(idx)}, r)) )
rdd2: org.apache.spark.rdd.RDD[(Int, Int, Int)] = MapPartitionsRDD[97] at mapPartitionsWithIndex at <console>:130

scala> rdd2.collect.foreach(println)
(0,1,107)                                                                       
(0,2,106)
(0,3,105)
(0,4,110)
(0,5,104)
(0,6,100)
(1,1,102)
(1,2,103)
(1,3,108)
(2,1,109)
(2,2,101)

scala> rownums.foreach(x=>println(x.value))
6
3
2

scala> 

2 个答案:

答案 0 :(得分:1)

Spark在分布式系统中运行。这意味着您无权修改功能之外的元素。

如果要获取包含每个分区计数的数组,则需要将RDD转换为RDD[Int],其中每行是分区的计数,然后收集它。

rdd.mapPartitions(itr => Iterator(itr.size))

如果分区索引很重要,您可以创建并RDD[Int,Int]将其与行数一起包括在内。

rdd.mapPartitionsWithIndex((idx, itr) => Iterator((idx, itr.size)))

答案 1 :(得分:0)

请阅读编程指南中的Understanding closures

  

在执行之前,Spark计算任务的关闭。闭包是那些变量和方法,它们必须是可见的,以便执行程序在RDD上执行其计算(在本例中为foreach())。该闭包被序列化并发送给每个执行者。

     

发送给每个执行程序的闭包内的变量现在是副本,因此,当在foreach函数中引用计数器时,它不再是驱动程序节点上的计数器。驱动程序节点的内存中仍然有一个计数器,但执行程序不再可见!执行程序只能看到序列化闭包中的副本。因此,计数器的最终值仍然为零,因为计数器上的所有操作都引用了序列化闭包中的值。

您正在修改变量的本地副本,而不是原始变量。