我需要从一个大的pyspark数据帧中收集分区/批处理,以便我可以迭代地将它们提供给神经网络
我的想法是1)分区数据,2)迭代收集每个分区,3)用toPandas()
转换收集的分区
我对foreachPartition
和mapPartitions
等方法感到困惑,因为我无法对它们进行迭代。有什么想法吗?
答案 0 :(得分:3)
您可以使用mapPartitions
将每个分区映射到元素列表中,并使用toLocalIterator
以迭代方式获取它们:
for partition in rdd.mapPartitions(lambda part: [list(part)]).toLocalIterator():
print(len(partition)) # or do something else :-)