如何有效地处理PySpark中的嵌套数据?

时间:2018-08-20 21:46:32

标签: python-3.x pyspark pyspark-sql

我在这里遇到了一个情况,发现当项目已经是列表时,spark中的collect_list效率不高。

基本上,我试图计算嵌套列表的均值(保证每个列表的大小都相同)。例如,当数据集变为1000万行时,它可能会产生内存不足错误。最初,我认为它与udf有关(以计算均值)。但是实际上,我发现聚合部分(列表的collect_list)才是真正的问题。

我现在要做的是将1000万行划分为多个块(按“用户”),分别汇总每个块,然后在最后合并它们。关于有效处理嵌套数据有更好的建议吗?

例如,玩具示例如下:

data = [('user1','place1', ['place1', 'place2', 'place3'], [0.0, 0.5, 0.4], [0.0, 0.4, 0.3]),
    ('user1','place2', ['place1', 'place2', 'place3'], [0.7, 0.0, 0.4], [0.6, 0.0, 0.3]),
    ('user2','place1', ['place1', 'place2', 'place3'], [0.0, 0.4, 0.3], [0.0, 0.3, 0.4]),
    ('user2','place3', ['place1', 'place2', 'place3'], [0.1, 0.2, 0.0], [0.3, 0.1, 0.0]),
    ('user3','place2', ['place1', 'place2', 'place3'], [0.3, 0.0, 0.4], [0.2, 0.0, 0.4]),
   ]
data_df = sparkApp.sparkSession.createDataFrame(data, ['user', 'place', 'places', 'data1', 'data2'])

data_agg = data_df.groupBy('user') \
    .agg(f.collect_list('place').alias('place_list'),
         f.first('places').alias('places'),
         f.collect_list('data1').alias('data1'),
         f.collect_list('data1').alias('data2'),
        )

import numpy as np
def average_values(sim_vectors):
    if len(sim_vectors) == 1:
        return sim_vectors[0]
    mat = np.array(sim_vectors)
    mean_vector = np.mean(mat, axis=0)
    return np.round(mean_vector, 3).tolist()

avg_vectors_udf = f.udf(average_values, ArrayType(DoubleType()))
data_agg_ave = data_agg.withColumn('data1', avg_vectors_udf('data1')) \
    .withColumn('data2', avg_vectors_udf('data2'))

结果将是:

+-----+----------------+--------------------+-----------------+-----------------+

| user|      place_list|              places|            data1|            data2|

+-----+----------------+--------------------+-----------------+-----------------+

|user1|[place1, place2]|[place1, place2, ...|[0.35, 0.25, 0.4]|[0.35, 0.25, 0.4]|

|user3|        [place2]|[place1, place2, ...|  [0.3, 0.0, 0.4]|  [0.3, 0.0, 0.4]|

|user2|[place1, place3]|[place1, place2, ...|[0.05, 0.3, 0.15]|[0.05, 0.3, 0.15]|

+-----+----------------+--------------------+-----------------+-----------------+

0 个答案:

没有答案