有没有办法为Spark RDD采样指定数量的元素而不是百分比?

时间:2017-01-24 02:23:24

标签: apache-spark rdd

我目前需要在Spark中为k元素随机抽取RDD中的项目。我注意到有takeSample方法。方法签名如下。

takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] 

但是,这不会返回RDD。还有另一种采样方法可以返回RDD,sample

sample(withReplacement: Boolean, fraction: Double, seed: Long = Utils.random.nextLong): RDD[T]

我不想使用第一种方法takeSample,因为它不会返回RDD并且会将大量数据提取回驱动程序(内存问题)。我继续使用sample方法,但我必须按如下方式计算fraction(百分比)。

val rdd = sc.textFile("some/path") //creates the rdd
val N = rdd.count() //total items in the rdd
val fraction = k / N.toDouble
val sampledRdd = rdd.sample(false, fraction, 67L)

这种方法/方法的问题在于我可能无法获得具有正好k项的RDD。例如,如果我们假设N = 10,那么

  • k = 2,分数= 20%,采样项目= 2
  • k = 3,分数= 30%,抽样项目= 3

但是N = 11,那么

  • k = 2,分数= 18.1818%,抽样项目=?
  • k = 3,分数= 27.2727%,采样项目=?

在最后一个例子中,对于分数= 18.1818%,生成的RDD中有多少项?

此外,这是documentation关于分数参数的说法。

expected size of the sample as a fraction of this RDD's size 
 - without replacement: probability that each element is chosen; fraction must be [0, 1] 
 - with replacement: expected number of times each element is chosen; fraction must be greater than or equal to 0

由于我选择without replacement,我的分数似乎应按如下方式计算。请注意,每个项目具有相同的选择概率(这是我想要表达的)。

val N = rdd.count()
val fraction = 1 / N.toDouble
val sampleRdd = rdd.sample(false, fraction, 67L)

那么,是k / N还是1 / N?似乎文档指向了样本大小和采样概率的所有不同方向。

最后,文档说明。

  

这不能保证准确提供给定RDD的计数部分。

然后,我回到原来的问题/关注:如果RDD API不能保证从RDD中精确采样k个项目,我们如何有效地这样做?

当我写这篇文章时,我发现已经another SO post提出了几乎相同的问题。我发现接受的答案是不可接受的。在这里,我还想澄清分数参数。

我想知道是否有办法使用数据集和数据框架?

1 个答案:

答案 0 :(得分:1)

这个解决方案不是那么漂亮,但我希望这对思考有帮助。 诀窍是使用额外的分数并获得第k个最大分数作为阈值。

val k = 100
val rdd = sc.parallelize(0 until 1000)
val rddWithScore = rdd.map((_, Math.random))
rddWithScore.cache()
val threshold = rddWithScore.map(_._2)
  .sortBy(t => t)
  .zipWithIndex()
  .filter(_._2 == k)
  .collect()
  .head._1
val rddSample = rddWithScore.filter(_._2 < threshold).map(_._1)
rddSample.count()

输出为

k: Int = 100
rdd: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[58] at parallelize at <console>:31
rddWithScore: org.apache.spark.rdd.RDD[(Int, Double)] = MapPartitionsRDD[59] at map at <console>:32
threshold: Double = 0.1180443408900893
rddSample: org.apache.spark.rdd.RDD[Int] = MapPartitionsRDD[69] at map at <console>:40
res10: Long = 100