避免PySpark中的for循环

时间:2019-01-25 18:29:15

标签: python-3.x pyspark

我有一个与此类似的PySpark DataFrame:

ID | value | period
a  |  100  |   1   
a  |  100  |   1   
b  |  100  |   1   
a  |  100  |   2   
b  |  100  |   2   
a  |  100  |   3

对于每个期间(1, 2, 3),我想将数据过滤到期间小于或等于该数字的位置,然后对每个ID的值列求和。

例如,周期1将给出(a:200, b:100),周期2将给出(a:300, b:200),周期3将给出(a:400, b:200)

此刻我正在循环执行此操作:

vals = [('a', 100, 1),
        ('a', 100, 1),
        ('b', 100, 1),
        ('a', 100, 2),
        ('b', 100, 2),
        ('a', 100, 3)]
cols = ['ID', 'value', 'period']
df = spark.createDataFrame(vals, cols)

for p in (1, 2, 3):
    df_filter = df[df['period'] <= p]
    results = df_filter.groupBy('ID').agg({'value':'sum'})

然后我将“结果”转换为大熊猫,并将其附加到一个DataFrame中。

是否有更好的方法可以执行此操作而不必使用循环? (实际上,我有数百个期间)。

1 个答案:

答案 0 :(得分:2)

这是使用pysparkpandas的组合解决方案;由于您说了数百个时期,所以这可能是一个可行的解决方案;基本上,首先使用pyspark聚合数据帧,然后将其转换为本地熊猫数据帧以进行进一步处理:

import pyspark.sql.functions as f

local_df = df.groupBy('period').pivot('ID').agg(f.sum('value')).toPandas()

local_df.sort_values('period').fillna(0).set_index('period').cumsum().reset_index()
#   period      a      b
#0       1  200.0  100.0
#1       2  300.0  200.0
#2       3  400.0  200.0