我目前需要在Spark中为k元素随机抽取RDD中的项目。我注意到有takeSample
方法。方法签名如下。
takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T]
但是,这不会返回RDD。还有另一种采样方法可以返回RDD,sample
。
sample(withReplacement: Boolean, fraction: Double, seed: Long = Utils.random.nextLong): RDD[T]
我不想使用第一种方法takeSample
,因为它不会返回RDD并且会将大量数据提取回驱动程序(内存问题)。我继续使用sample
方法,但我必须按如下方式计算fraction
(百分比)。
val rdd = sc.textFile("some/path") //creates the rdd
val N = rdd.count() //total items in the rdd
val fraction = k / N.toDouble
val sampledRdd = rdd.sample(false, fraction, 67L)
这种方法/方法的问题在于我可能无法获得具有正好k项的RDD。例如,如果我们假设N = 10,那么
但是N = 11,那么
在最后一个例子中,对于分数= 18.1818%,生成的RDD中有多少项?
此外,这是documentation关于分数参数的说法。
expected size of the sample as a fraction of this RDD's size - without replacement: probability that each element is chosen; fraction must be [0, 1] - with replacement: expected number of times each element is chosen; fraction must be greater than or equal to 0
由于我选择without replacement
,我的分数似乎应按如下方式计算。请注意,每个项目具有相同的选择概率(这是我想要表达的)。
val N = rdd.count()
val fraction = 1 / N.toDouble
val sampleRdd = rdd.sample(false, fraction, 67L)
那么,是k / N
还是1 / N
?似乎文档指向了样本大小和采样概率的所有不同方向。
最后,文档说明。
这不能保证准确提供给定RDD的计数部分。
然后,我回到原来的问题/关注:如果RDD API不能保证从RDD中精确采样k个项目,我们如何有效地这样做?
当我写这篇文章时,我发现已经another SO post提出了几乎相同的问题。我发现接受的答案是不可接受的。在这里,我还想澄清分数参数。
我想知道是否有办法使用数据集和数据框架?
答案 0 :(得分:1)
这个解决方案不是那么漂亮,但我希望这对思考有帮助。 诀窍是使用额外的分数并获得第k个最大分数作为阈值。
val k = 100
val rdd = sc.parallelize(0 until 1000)
val rddWithScore = rdd.map((_, Math.random))
rddWithScore.cache()
val threshold = rddWithScore.map(_._2)
.sortBy(t => t)
.zipWithIndex()
.filter(_._2 == k)
.collect()
.head._1
val rddSample = rddWithScore.filter(_._2 < threshold).map(_._1)
rddSample.count()
输出为
k: Int = 100
rdd: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[58] at parallelize at <console>:31
rddWithScore: org.apache.spark.rdd.RDD[(Int, Double)] = MapPartitionsRDD[59] at map at <console>:32
threshold: Double = 0.1180443408900893
rddSample: org.apache.spark.rdd.RDD[Int] = MapPartitionsRDD[69] at map at <console>:40
res10: Long = 100