We have the following dataframe:
+------+--------------------+
| Flag | value|
+------+--------------------+
|1 |5 |
|1 |4 |
|1 |3 |
|1 |5 |
|1 |6 |
|1 |4 |
|1 |7 |
|1 |5 |
|1 |2 |
|1 |3 |
|1 |2 |
|1 |6 |
|1 |9 |
+------+--------------------+
After a normal cumsum, we get this:
+------+--------------------+----------+
| Flag | value|cumsum |
+------+--------------------+----------+
|1 |5 |5 |
|1 |4 |9 |
|1 |3 |12 |
|1 |5 |17 |
|1 |6 |23 |
|1 |4 |27 |
|1 |7 |34 |
|1 |5 |39 |
|1 |2 |41 |
|1 |3 |44 |
|1 |2 |46 |
|1 |6 |52 |
|1 |9 |61 |
+------+--------------------+----------+
Now what we want is to reset the cumsum when a specific condition is met, for example when it crosses 20.
The expected output is as follows:
+------+--------------------+----------+---------+
| Flag | value|cumsum |expected |
+------+--------------------+----------+---------+
|1 |5 |5 |5 |
|1 |4 |9 |9 |
|1 |3 |12 |12 |
|1 |5 |17 |17 |
|1 |6 |23 |23 |
|1 |4 |27 |4 | <-----reset
|1 |7 |34 |11 |
|1 |5 |39 |16 |
|1 |2 |41 |18 |
|1 |3 |44 |21 |
|1 |2 |46 |2 | <-----reset
|1 |6 |52 |8 |
|1 |9 |61 |17 |
+------+--------------------+----------+---------+
This is how we compute the cumulative sum:
win_counter = Window.partitionBy("flag")
df_partitioned = df_partitioned.withColumn('cumsum',F.sum(F.col('value')).over(win_counter))
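For reference, a window-based running sum needs an explicit ordering and frame; a minimal sketch, assuming a hypothetical monotonically increasing id column to order the rows by:

from pyspark.sql import Window
import pyspark.sql.functions as F

# ordinary running cumsum: order rows within each flag partition and
# sum everything from the start of the partition up to the current row
win_counter = (Window.partitionBy('flag')
               .orderBy('id')
               .rowsBetween(Window.unboundedPreceding, Window.currentRow))
df = df.withColumn('cumsum', F.sum(F.col('value')).over(win_counter))

A plain window expression like this cannot reset based on its own running value, which is why both answers below fall back to pandas UDFs.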
Answer 0 (score: 1)
It's better to use a pandas_udf here:
import math
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType

pdf = pd.DataFrame({'flag': [1] * 13,
                    'id': range(13),
                    'value': [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]})
df = spark.createDataFrame(pdf)
# placeholder cumsum column so the input schema already matches the UDF's output schema
df = df.withColumn('cumsum', F.lit(math.inf))

@pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
def _calc_cumsum(pdf):
    # keep the original row order within the group
    pdf.sort_values(by=['id'], inplace=True, ascending=True)
    cumsums = []
    prev = None
    reset = False
    for v in pdf['value'].values:
        # start (or restart) the running sum, otherwise keep adding
        prev = v if prev is None or reset else prev + v
        cumsums.append(prev)
        # once the running sum reaches 20, reset on the next row
        reset = prev >= 20
    pdf['cumsum'] = cumsums
    return pdf

df = df.groupby('flag').apply(_calc_cumsum)
df.show()
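On Spark 3.x the GROUPED_MAP pandas_udf style used above is deprecated; the same function can be passed to applyInPandas instead. A minimal sketch, assuming _calc_cumsum is defined exactly as above but without the @pandas_udf decorator:

# apply the plain pandas function per group, reusing the DataFrame's schema
df = df.groupby('flag').applyInPandas(_calc_cumsum, schema=df.schema)
df.show()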
Answer 1 (score: 0)
Building on @niuer's solution, I created another approach using GroupBy. In this case the final DataFrame will not have the value column, only the flag and the sum.
Requirements:
pyspark >= 3.0.0
pandas >= 0.23.4
PyArrow >= 0.15.1
Code:
import pyspark.sql.functions as f
from pyspark.sql import Row
from pyspark.shell import spark


def __create_rows():
    for value in [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]:
        yield Row(Flag=1, value=value)


df = spark.createDataFrame(data=list(__create_rows()))


@f.pandas_udf('array<int>', f.PandasUDFType.GROUPED_AGG)
def cumsum(iterator):
    def iterate():
        total = 0
        for value in iterator.values:
            # reset the running total on the row after it crosses 20
            if total > 20:
                total = 0
            total = total + value
            yield total

    return list(iterate())


df = (df
      .groupby('flag')
      .agg(cumsum(f.col('value')).alias('cumsum')))

# the aggregate returns one array per group; explode it back into one row per value
df = df.withColumn('cumsum', f.explode('cumsum'))

df.show()
Output:
+----+------+
|flag|cumsum|
+----+------+
| 1| 5|
| 1| 9|
| 1| 12|
| 1| 17|
| 1| 23|
| 1| 4|
| 1| 11|
| 1| 16|
| 1| 18|
| 1| 21|
| 1| 2|
| 1| 8|
| 1| 17|
+----+------+