Pyspark: cumulative sum with a reset condition

Asked: 2017-11-18 09:52:33

Tags: pyspark cumulative-sum

We have a dataframe like the one below:

+------+--------------------+
| Flag |               value|
+------+--------------------+
|1     |5                   |
|1     |4                   |
|1     |3                   |
|1     |5                   |
|1     |6                   |
|1     |4                   |
|1     |7                   |
|1     |5                   |
|1     |2                   |
|1     |3                   |
|1     |2                   |
|1     |6                   |
|1     |9                   |      
+------+--------------------+

After a normal cumsum, we get this:

+------+--------------------+----------+
| Flag |               value|cumsum    |
+------+--------------------+----------+
|1     |5                   |5         |
|1     |4                   |9         |
|1     |3                   |12        |
|1     |5                   |17        |
|1     |6                   |23        |
|1     |4                   |27        |
|1     |7                   |34        |
|1     |5                   |39        |
|1     |2                   |41        |
|1     |3                   |44        |
|1     |2                   |46        |
|1     |6                   |52        |
|1     |9                   |61        |       
+------+--------------------+----------+

Now what we want is for the cumsum to reset when a certain condition is met, for example when it crosses 20.

The expected output would be as follows:

+------+--------------------+----------+---------+
| Flag |               value|cumsum    |expected |
+------+--------------------+----------+---------+
|1     |5                   |5         |5        |
|1     |4                   |9         |9        |
|1     |3                   |12        |12       |
|1     |5                   |17        |17       |
|1     |6                   |23        |23       |
|1     |4                   |27        |4        |  <-----reset 
|1     |7                   |34        |11       |
|1     |5                   |39        |16       |
|1     |2                   |41        |18       |
|1     |3                   |44        |21       |
|1     |2                   |46        |2        |  <-----reset
|1     |6                   |52        |8        |
|1     |9                   |61        |17       |         
+------+--------------------+----------+---------+
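The reset rule behind the expected column can be sketched in plain Python, independent of Spark (the `cumsum_with_reset` name and the `threshold` parameter are illustrative, not from the question):

```python
def cumsum_with_reset(values, threshold=20):
    """Running sum that restarts after the previous total crossed the threshold."""
    out, total = [], 0
    for v in values:
        if total > threshold:
            total = 0  # the previous row crossed the threshold, so restart
        total += v
        out.append(total)
    return out

print(cumsum_with_reset([5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]))
# [5, 9, 12, 17, 23, 4, 11, 16, 18, 21, 2, 8, 17]
```

Note that the row that crosses 20 keeps its full total (23, 21); only the row after it restarts from zero, which matches the expected column above.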

This is how we are calculating the cumulative sum:

# An ordering column (here assumed to be 'id') is needed; without orderBy,
# the windowed sum is the group total on every row, not a running sum.
win_counter = Window.partitionBy('flag').orderBy('id') \
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)

df_partitioned = df_partitioned.withColumn('cumsum', F.sum(F.col('value')).over(win_counter))
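For reference, the plain (non-resetting) cumsum column shown in the table above can be reproduced in pandas alone, which is handy for checking the Spark result on a sample (a sketch using only the sample data from the question):

```python
import pandas as pd

df = pd.DataFrame({'Flag': [1] * 13,
                   'value': [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]})
# groupby + cumsum mirrors a Spark window partitioned by Flag with a running-sum frame
df['cumsum'] = df.groupby('Flag')['value'].cumsum()
print(df['cumsum'].tolist())
# [5, 9, 12, 17, 23, 27, 34, 39, 41, 44, 46, 52, 61]
```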

2 Answers:

Answer 0 (score: 1)

It's better to use a pandas_udf here:

import math

import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType

pdf = pd.DataFrame({'flag': [1] * 13,
                    'id': range(13),
                    'value': [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]})
df = spark.createDataFrame(pdf)
df = df.withColumn('cumsum', F.lit(math.inf))  # placeholder so df.schema already contains the cumsum column

@pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
def _calc_cumsum(pdf):
    # Rows within a group are not guaranteed to arrive in order, so sort by id first.
    pdf = pdf.sort_values(by=['id'], ascending=True)
    cumsums = []
    prev = None
    reset = False
    for v in pdf['value'].values:
        prev = v if prev is None or reset else prev + v
        cumsums.append(prev)
        reset = prev > 20  # restart the running total on the next row once it crosses 20

    pdf['cumsum'] = cumsums
    return pdf

df = df.groupby('flag').apply(_calc_cumsum)
df.show()

Answer 1 (score: 0)

Based on @niuer's solution, I created an alternative using GroupBy. In this case, the final DataFrame does not have the value column, only the flag and the sums.

Requirements:
pyspark >= 3.0.0
pandas >= 0.23.4
PyArrow >= 0.15.1

Code:

import pyspark.sql.functions as f
from pyspark import Row
from pyspark.shell import spark


def __create_rows():
    for value in [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]:
        yield Row(Flag=1, value=value)


df = spark.createDataFrame(data=list(__create_rows()))


@f.pandas_udf('array<int>', f.PandasUDFType.GROUPED_AGG)
def cumsum(iterator):
    def iterate():
        # Running total that restarts after it has crossed 20.
        total = 0
        for value in iterator.values:
            if total > 20:
                total = 0

            total += value
            yield total

    return list(iterate())


df = (df
      .groupby('flag')
      .agg(cumsum(f.col('value')).alias('cumsum')))
df = df.withColumn('cumsum', f.explode('cumsum'))
df.show()

Output:

+----+------+
|flag|cumsum|
+----+------+
|   1|     5|
|   1|     9|
|   1|    12|
|   1|    17|
|   1|    23|
|   1|     4|
|   1|    11|
|   1|    16|
|   1|    18|
|   1|    21|
|   1|     2|
|   1|     8|
|   1|    17|
+----+------+