我创建了一个合成数据集,并尝试基于一列进行重新分区。目标是最终获得平衡(相同大小)数量的分区,但我无法实现这一点。有没有办法做到这一点,最好不要求助于 RDD 并保存数据帧?
示例代码:
from pyspark.sql import SparkSession
from pyspark.sql.types import *
import pyspark.sql.functions as f
spark = SparkSession.builder.appName('learn').getOrCreate()
import pandas as pd
import random
from pyspark.sql.types import *
nr = 500
data = {'id': [random.randint(0,5) for _ in range(nr)], 'id2': [random.randint(0,5) for _ in range(nr)]}
data = pd.DataFrame(data)
df = spark.createDataFrame(data)
# df.show()
df = df.repartition(3, 'id')
# see the different partitions
for ipart in range(3):
print(f'partition {ipart}')
def fpart(partition_idx, iterator, target_partition_idx=ipart):
if partition_idx == target_partition_idx:
return iterator
else:
return iter(())
res = df.rdd.mapPartitionsWithIndex(fpart)
res = res.toDF(schema=schema)
# res.show(n=5, truncate=False)
print(f"number of rows {res.count()}, unique ids {res.select('id').drop_duplicates().toPandas()['id'].tolist()}")
它产生:
partition 0
number of rows 79, unique ids [3]
partition 1
number of rows 82, unique ids [0]
partition 2
number of rows 339, unique ids [5, 1, 2, 4]
所以分区显然不平衡。
我在 How to guarantee repartitioning in Spark Dataframe 中看到这是可以解释的,因为分配给分区是基于列 id 模 3(分区数)的哈希值:
df.select('id', f.expr("hash(id)"), f.expr("pmod(hash(id), 3)")).drop_duplicates().show()
产生
+---+-----------+-----------------+
| id| hash(id)|pmod(hash(id), 3)|
+---+-----------+-----------------+
| 3| 519220707| 0|
| 0|-1670924195| 1|
| 1|-1712319331| 2|
| 5| 1607884268| 2|
| 4| 1344313940| 2|
| 2| -797927272| 2|
+---+-----------+-----------------+
但我觉得这很奇怪。在重新分区函数中指定列的目的是以某种方式将 id 的值拆分到不同的分区。如果列 id 的唯一值比本例中的 6 多,效果会更好,但仍然如此。
有没有办法做到这一点?