如何在PySpark DataFrame中强制进行某个分区?

时间:2018-06-08 09:04:21

标签: apache-spark pyspark partitioning

假设我有一个带有partition_id列的DataFrame:

n_partitions = 2

df = spark.sparkContext.parallelize([
    [1, 'A'],
    [1, 'B'],
    [2, 'A'],
    [2, 'C']
]).toDF(('partition_id', 'val'))

如何重新分区DataFrame以保证partition_id的每个值都转到一个单独的分区,并且实际分区的数量与partition_id的不同值完全相同?

如果我执行散列分区,即df.repartition(n_partitions, 'partition_id'),则可以保证正确的分区数,但是由于散列冲突,某些分区可能为空,而其他分区可能包含多个partition_id值。

1 个答案:

答案 0 :(得分:6)

Python和DataFrame API没有这样的选项。 Dataset中的分区API不可插入,仅支持预定义的range and hash partitioning schemes

您可以将数据转换为RDD,使用自定义分区程序进行分区,然后将转换回DataFrame

from pyspark.sql.functions import col, struct, spark_partition_id

mapping = {k: i for i, k in enumerate(
    df.select("partition_id").distinct().rdd.flatMap(lambda x: x).collect()
)}

result = (df
    .select("partition_id", struct([c for c in df.columns]))
    .rdd.partitionBy(len(mapping), lambda k: mapping[k])
    .values()
    .toDF(df.schema))

result.withColumn("actual_partition_id", spark_partition_id()).show()
# +------------+---+-------------------+
# |partition_id|val|actual_partition_id|
# +------------+---+-------------------+
# |           1|  A|                  0|
# |           1|  B|                  0|
# |           2|  A|                  1|
# |           2|  C|                  1|
# +------------+---+-------------------+

请记住,这只会创建特定的数据分布,并且不会设置Catalyst优化器可以使用的分区程序。