How to create an RDD whose partitions are built by stratified sampling of an original RDD?

时间:2018-10-26 09:23:43

标签: apache-spark sampling

I have an RDD with a large number of records and features; call it the original RDD. I want to create another RDD, say sampledRDD, such that each partition of sampledRDD is a stratified sample of the original RDD.

        Original RDD                            Sampled RDD
+----------+---------------+------+   +----------+---------------+------+   
|rowNumber | other features| label|   |rowNumber | other features| label| 
+----------+---------------+------+   +----------+---------------+------+ 
|1         |some values....|A     |   |1         |some values....|A     #
|2         |some values....|A     |   |2         |some values....|A     # 
|3         |some values....|A     |   |3         |some values....|A     #  
|4         |some values....|A     |   |8         |some values....|B     # <= Partition 0
|5         |some values....|A     |   |9         |some values....|B     #   
|6         |some values....|A     |   |13        |some values....|c     #_______________   
|7         |some values....|A     |   |6         |some values....|A     $  
|8         |some values....|B     |   |7         |some values....|A     $  
|9         |some values....|B     |   |4         |some values....|A     $  
|10        |some values....|B     |   |5         |some values....|A     $ <= Partition 1
|11        |some values....|B     |   |10        |some values....|B     $  
|12        |some values....|c     |   |11        |some values....|B     $  
|13        |some values....|c     |   |12        |some values....|c     $  
+----------+---------------+------+   +----------+---------------+------+  

The diagram above illustrates the idea.
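Concretely, stratified sampling preserves the per-label proportions in every sample. The code below builds the per-label fraction map that `sampleByKey` expects; here is a small Spark-free sketch of that step (object and method names are illustrative, not from the original code):

```scala
object FractionsDemo {
  // Build the Map[label -> fraction] that sampleByKey expects; every
  // label (stratum) is sampled at the same rate.
  def uniformFractions(labels: Seq[String], rate: Double): Map[String, Double] =
    labels.distinct.map(l => (l, rate)).toMap

  def main(args: Array[String]): Unit = {
    // Labels of the 13 rows in the diagram above.
    val labels = Seq("A", "A", "A", "A", "A", "A", "A", "B", "B", "B", "B", "c", "c")
    println(uniformFractions(labels, 0.5)) // Map(A -> 0.5, B -> 0.5, c -> 0.5)
  }
}
```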

My solution is below, but it is too slow and does not seem to scale to large datasets.

import org.apache.spark.Partitioner
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import scala.util.Random

def StratifiedPartitions(
    data: RDD[Row],
    rate: Double,
    nPartition: Int,
    schema: baseSchema,
    withReplacement: Boolean = false,
    seed: Long = System.currentTimeMillis()): RDD[Row] = {

  // Routes each record to the partition index it was tagged with.
  class partitioner(override val numPartitions: Int) extends Partitioner {
    def getPartition(key: Any): Int = key.toString.toInt
  }

  val r = new Random(seed)

  // One sampling fraction per label (stratum), as required by sampleByKey.
  val fractions = data
    .map(row => row.getByte(schema.cIndex))
    .distinct
    .map(x => (x, rate))
    .collectAsMap

  // Draw one stratified sample per target partition, tag each record with
  // its partition index, and union all the samples together.
  val samples = (0 until nPartition).map(idx =>
    data
      .map(row => (row.getByte(schema.cIndex), row))
      .sampleByKey(withReplacement, fractions, r.nextInt())
      .map(pair => (idx, pair._2))
  ).reduce(_ union _)

  samples.partitionBy(new partitioner(nPartition)).map(_._2)
}
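For reference, the custom Partitioner above simply routes each tagged record to the partition whose index it carries as a key. A minimal Spark-free sketch of that routing logic (class and value names are illustrative):

```scala
// Sketch of the identity-style partitioner used above: each record is
// tagged with a target partition index, and getPartition returns it.
class IndexPartitioner(val numPartitions: Int) {
  def getPartition(key: Any): Int = key.toString.toInt
}

object IndexPartitionerDemo {
  def main(args: Array[String]): Unit = {
    val p = new IndexPartitioner(2)
    // (partitionIdx, row) pairs, as produced by the sampling loop above.
    val tagged = Seq((0, "row1"), (1, "row8"), (0, "row13"))
    tagged.foreach { case (idx, row) =>
      println(s"$row -> partition ${p.getPartition(idx)}")
    }
  }
}
```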

0 Answers:

No answers yet