Running sum / cumulative sum with a floor and a cap in PySpark

Posted: 2019-12-06 12:11:46

Tags: apache-spark pyspark pyspark-sql pyspark-dataframes

I am new to Spark, and I am trying to compute a windowed running sum with a floor of 0 and a cap of 8.

Below is a toy example (note that the real data is closer to millions of rows):

import pyspark.sql.functions as F
from pyspark.sql import Window
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

pdf = pd.DataFrame({'aIds':    [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
                    'day':     [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4],
                    'eCounts': [-3, 3, -6, 3, 3, 6, -3, -6, 3, 3, 3, -3]})
sdf = spark.createDataFrame(pdf)
sdf = sdf.orderBy(sdf.aIds, sdf.day)

This creates the following table:

+----+---+-------+
|aIds|day|eCounts|
+----+---+-------+
|   1|  1|     -3|
|   1|  2|      3|
|   1|  3|     -6|
|   1|  4|      3|
|   2|  1|      3|
|   2|  2|      6|
|   2|  3|     -3|
|   2|  4|     -6|
|   3|  1|      3|
|   3|  2|      3|
|   3|  3|      3|
|   3|  4|     -3|
+----+---+-------+

Below is an example of the plain running sum (runSum), together with the expected output, runSumCap:

+----+---+-------+------+---------+
|aIds|day|eCounts|runSum|runSumCap|
+----+---+-------+------+---------+
|   1|  1|     -3|    -3|        0| <-- reset to 0
|   1|  2|      3|     0|        3|
|   1|  3|     -6|    -6|        0| <-- reset to 0
|   1|  4|      3|    -3|        3|
|   2|  1|      3|     3|        3|
|   2|  2|      6|     9|        8| <-- reset to 8
|   2|  3|     -3|     6|        5| 
|   2|  4|     -6|     0|        0| <-- reset to 0
|   3|  1|      3|     3|        3|
|   3|  2|      3|     6|        6|
|   3|  3|      3|     9|        8| <-- reset to 8
|   3|  4|     -3|     6|        5|
+----+---+-------+------+---------+

I know I can compute the running sum as follows:

partition = Window.partitionBy('aIds').orderBy('aIds','day').rowsBetween(Window.unboundedPreceding, Window.currentRow)
sdf1 = sdf.withColumn('runSum',F.sum(sdf.eCounts).over(partition))
sdf1.orderBy('aIds','day').show()

To get the desired output, I tried looking into @pandas_udf to modify the sum:

@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def runSumCap(counts):
    # the counts column is passed in as a pandas Series
    floor = 0
    cap = 8
    runSum = 0
    runSumList = []
    for count in counts.tolist():
        runSum = runSum + count
        if runSum > cap:
            runSum = cap
        elif runSum < floor:
            runSum = floor
        runSumList += [runSum]
    return pd.Series(runSumList)


partition = Window.partitionBy('aIds').orderBy('aIds','day').rowsBetween(Window.unboundedPreceding, Window.currentRow)
sdf1 = sdf.withColumn('runSum',runSumCap(sdf['eCounts']).over(partition))

However, this does not work, and it does not seem like the most efficient approach. How can I get this working? Is there a way to keep it parallel, or do I have to fall back to a pandas DataFrame?

EDIT: Tidied up the column names used to order the dataset and added more insight into what I am trying to achieve.

EDIT2: The answer provided by @DrChess almost yields the correct result, but for some reason the resulting series does not line up with the correct days:

+----+---+-------+------+
|aIds|day|eCounts|runSum|
+----+---+-------+------+
|   1|  1|     -3|     0|
|   1|  2|      3|     0|
|   1|  3|     -6|     3|
|   1|  4|      3|     3|
|   2|  1|      3|     3|
|   2|  2|      6|     8|
|   2|  3|     -3|     0|
|   2|  4|     -6|     5|
|   3|  1|      3|     6|
|   3|  2|      3|     3|
|   3|  3|      3|     8|
|   3|  4|     -3|     5|
+----+---+-------+------+

2 Answers:

Answer 0 (score: 1)

Unfortunately, window functions with a pandas_udf of type GROUPED_AGG do not work with bounded windows (i.e. .rowsBetween(Window.unboundedPreceding, Window.currentRow)). They currently only work with unbounded windows, i.e. .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing). Additionally, the input is a pandas.Series, but the output must be a single constant of the provided type, so you cannot produce a per-row partial aggregation this way.

Instead, you can use a pandas_udf of type GROUPED_MAP, applied with groupBy(...).apply(...). Here is some code:

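A minimal sketch of that GROUPED_MAP approach, assuming the column names from the question; the output schema, the function name capped_run_sum, and the per-group sort by day are illustrative choices rather than a definitive implementation:

from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, LongType

# Output schema of the DataFrame returned for each group
schema = StructType([
    StructField('aIds', LongType()),
    StructField('day', LongType()),
    StructField('eCounts', LongType()),
    StructField('runSumCap', LongType()),
])

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def capped_run_sum(pdf):
    # Each call receives all rows for one aIds value as a pandas DataFrame.
    # Sorting by day keeps the capped running sum aligned with the right rows.
    pdf = pdf.sort_values('day')
    total, floor, cap = 0, 0, 8
    sums = []
    for count in pdf['eCounts']:
        total = min(max(total + count, floor), cap)
        sums.append(total)
    pdf['runSumCap'] = sums
    return pdf

result = sdf.groupBy('aIds').apply(capped_run_sum)
result.orderBy('aIds', 'day').show()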

Answer 1 (score: 1)

I found a way to do this by first creating, in each row, an array (using collect_list as a window function) that contains the values making up the running sum up to that point. I then defined a udf (I could not get this to work with a pandas_udf), and it works. Below is a fully reproducible example:

import pyspark.sql.functions as F
from pyspark.sql import Window
from pyspark.sql.types import LongType
import pandas as pd

def accumulate(iterable):
    # Sum the elements in order, clamping the running total to [0, 8]
    total = 0
    ceil = 8
    floor = 0
    for element in iterable:
        total = total + element
        if total > ceil:
            total = ceil
        elif total < floor:
            total = floor
    return total

pdf = pd.DataFrame({'aIds':    [1,  1,  1,  1, 2, 2,  2,  2, 3, 3, 3,  3],
                    'day':     [1,  2,  3,  4, 1, 2,  3,  4, 1, 2, 3,  4],
                    'eCounts': [-3, 3, -6,  3, 3, 6, -3, -6, 3, 3, 3, -3]})

sdf = spark.createDataFrame(pdf)
sdf = sdf.orderBy(sdf.aIds, sdf.day)

runSumCap = F.udf(accumulate, LongType())
partition = Window.partitionBy('aIds').orderBy('aIds','day').rowsBetween(Window.unboundedPreceding, Window.currentRow)
sdf1 = sdf.withColumn('splitWindow', F.collect_list(sdf.eCounts).over(partition))
sdf2 = sdf1.withColumn('runSumCap', runSumCap(sdf1.splitWindow))
sdf2.orderBy('aIds','day').show()

This produces the expected result:

+----+---+-------+--------------+---------+
|aIds|day|eCounts|   splitWindow|runSumCap|
+----+---+-------+--------------+---------+
|   1|  1|     -3|          [-3]|        0|
|   1|  2|      3|       [-3, 3]|        3|
|   1|  3|     -6|   [-3, 3, -6]|        0|
|   1|  4|      3|[-3, 3, -6, 3]|        3|
|   2|  1|      3|           [3]|        3|
|   2|  2|      6|        [3, 6]|        8|
|   2|  3|     -3|    [3, 6, -3]|        5|
|   2|  4|     -6|[3, 6, -3, -6]|        0|
|   3|  1|      3|           [3]|        3|
|   3|  2|      3|        [3, 3]|        6|
|   3|  3|      3|     [3, 3, 3]|        8|
|   3|  4|     -3| [3, 3, 3, -3]|        5|
+----+---+-------+--------------+---------+