在给定分区ID的情况下,有没有办法获得spark RDD分区中的元素数量?不扫描整个分区。
这样的事情:
Rdd.partitions().get(index).size()
除非我没有看到这样的火花API。有任何想法吗?解决方法?
由于
答案 0 :(得分:22)
以下为您提供了一个新的RDD,其元素是每个分区的大小:
rdd.mapPartitions(iter => Array(iter.size).iterator, true)
答案 1 :(得分:16)
PySpark:
num_partitions = 20000
a = sc.parallelize(range(int(1e6)), num_partitions)
l = a.glom().map(len).collect() # get length of each partition
print(min(l), max(l), sum(l)/len(l), len(l)) # check if skewed
火花/阶:
val numPartitions = 20000
val a = sc.parallelize(0 until 1e6.toInt, numPartitions )
val l = a.glom().map(_.length).collect() # get length of each partition
print(l.min, l.max, l.sum/l.length, l.length) # check if skewed
致谢:Mike Dusenberry @ https://issues.apache.org/jira/browse/SPARK-17817
对于数据帧也是如此,而不仅仅是RDD。 只需将DF.rdd.glom ...添加到上面的代码中。
答案 2 :(得分:2)
rdd.mapPartitions(iter => Iterator(iter.size), true).collect()
P.S。不确定他的答案是否实际上做得更多,因为Iterator.apply可能会将其参数转换为数组。
答案 3 :(得分:0)
我知道我来晚了一点,但是我还有另一种方法可以利用spark的内置函数来获取分区中元素的数量。它适用于2.1以上的spark版本。
说明: 我们将创建一个示例数据帧(df),获取分区ID,对分区ID进行分组,并对每个记录进行计数。
Pyspark:
>>> from pyspark.sql.functions import spark_partition_id, count as _count
>>> df = spark.sql("set -v").unionAll(spark.sql("set -v")).repartition(4)
>>> df.rdd.getNumPartitions()
4
>>> df.withColumn("partition_id", spark_partition_id()).groupBy("partition_id").agg(_count("key")).orderBy("partition_id").show()
+------------+----------+
|partition_id|count(key)|
+------------+----------+
| 0| 48|
| 1| 44|
| 2| 32|
| 3| 48|
+------------+----------+
斯卡拉:
scala> val df = spark.sql("set -v").unionAll(spark.sql("set -v")).repartition(4)
df: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [key: string, value: string ... 1 more field]
scala> df.rdd.getNumPartitions
res0: Int = 4
scala> df.withColumn("partition_id", spark_partition_id()).groupBy("partition_id").agg(count("key")).orderBy("partition_id").show()
+------------+----------+
|partition_id|count(key)|
+------------+----------+
| 0| 48|
| 1| 44|
| 2| 32|
| 3| 48|
+------------+----------+