如何获取分区中的元素数量?

时间:2015-02-24 02:20:41

标签: apache-spark partitioning

在给定分区ID的情况下,有没有办法获得spark RDD分区中的元素数量?不扫描整个分区。

这样的事情:

Rdd.partitions().get(index).size()

除非我没有看到这样的火花API。有任何想法吗?解决方法?

由于

4 个答案:

答案 0 :(得分:22)

以下为您提供了一个新的RDD,其元素是每个分区的大小:

rdd.mapPartitions(iter => Array(iter.size).iterator, true) 

答案 1 :(得分:16)

PySpark:

num_partitions = 20000
a = sc.parallelize(range(int(1e6)), num_partitions)
l = a.glom().map(len).collect()  # get length of each partition
print(min(l), max(l), sum(l)/len(l), len(l))  # check if skewed

火花/阶:

val numPartitions = 20000
val a = sc.parallelize(0 until 1e6.toInt, numPartitions )
val l = a.glom().map(_.length).collect()  # get length of each partition
print(l.min, l.max, l.sum/l.length, l.length)  # check if skewed

致谢:Mike Dusenberry @ https://issues.apache.org/jira/browse/SPARK-17817

对于数据帧也是如此,而不仅仅是RDD。 只需将DF.rdd.glom ...添加到上面的代码中。

答案 2 :(得分:2)

pzecevic的回答是有效的,但从概念上讲,没有必要构造一个数组,然后将其转换为迭代器。我只是直接构造迭代器,然后通过collect调用获取计数。

rdd.mapPartitions(iter => Iterator(iter.size), true).collect()

P.S。不确定他的答案是否实际上做得更多,因为Iterator.apply可能会将其参数转换为数组。

答案 3 :(得分:0)

我知道我来晚了一点,但是我还有另一种方法可以利用spark的内置函数来获取分区中元素的数量。它适用于2.1以上的spark版本。

说明: 我们将创建一个示例数据帧(df),获取分区ID,对分区ID进行分组,并对每个记录进行计数。

Pyspark:

>>> from pyspark.sql.functions import spark_partition_id, count as _count
>>> df = spark.sql("set -v").unionAll(spark.sql("set -v")).repartition(4)
>>> df.rdd.getNumPartitions()
4
>>> df.withColumn("partition_id", spark_partition_id()).groupBy("partition_id").agg(_count("key")).orderBy("partition_id").show()
+------------+----------+
|partition_id|count(key)|
+------------+----------+
|           0|        48|
|           1|        44|
|           2|        32|
|           3|        48|
+------------+----------+

斯卡拉:

scala> val df = spark.sql("set -v").unionAll(spark.sql("set -v")).repartition(4)
df: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [key: string, value: string ... 1 more field]

scala> df.rdd.getNumPartitions
res0: Int = 4

scala> df.withColumn("partition_id", spark_partition_id()).groupBy("partition_id").agg(count("key")).orderBy("partition_id").show()
+------------+----------+
|partition_id|count(key)|
+------------+----------+
|           0|        48|
|           1|        44|
|           2|        32|
|           3|        48|
+------------+----------+