超过阈值后重置的累积和

时间:2019-05-18 17:30:02

标签: pyspark

我对pyspark和python比较陌生。这可能是微不足道的,但我不太明白。

我有一个数据集,其中包含一个ID,一个DATE列和一个包含浮点数的X列。我想做的是,根据ID并按DATE排序,计算X的运行总计。当该运行总计超过值Y时,然后重新启动总计。下面是一个示例,其中Y = 20:

|ID |DATE        | X   | cumsum_X |
-----------------------------------
|  1|  2017-03-01|   10|        10|
|  1|  2017-03-02|   12|        22|
|  1|  2017-03-03|    5|         5|
|  1|  2017-03-04|   10|        15|
|  2|  2015-01-01|    6|         6|
|  2|  2015-01-03|    7|        13|

我已经能够计算出该组的累积总和,并使用下面的代码正确地排序了,但是我不确定如何修改它以获得所需的重置行为。

win =(Window(partitionBy('ID').over('DATE'))
      .rangeBetween(Window.unboundedPreceding, 0))

df = df.withColumn('cumsum_x', F.Sum('X').over(win))

1 个答案:

答案 0 :(得分:0)

这是一个使用小技巧来重置累积总和的解决方案。

首先,我计算每个 ID 的运行总和。之后,我将结果总和分组到取决于重置值的箱中。最后,我使用原始 ID 和 ID 内 bin 来计算重置累积和。

代码:

    # imports
    from pyspark.sql.functions import col
    import pyspark.sql.functions as F
    from pyspark.sql.window import Window
    
    
    # value for resetting cummulative sum
    reset_value = 20
    
    # dummy DataFrame
    df = spark.createDataFrame([
      [1, 1, 10],
      [1, 2, 12],
      [1, 3, 5],
      [1, 4, 10],
      [2, 1, 6],
      [2, 3, 7]
    ], schema=["id", "t", "x"])
    
    # windows
    w1 = Window().partitionBy("id").orderBy("t")
    w2 = Window().partitionBy("id", "cumsum_group").orderBy("t").rangeBetween(Window.unboundedPreceding, 0)
    
    # reset AFTER reset value is reached
    df2 = (df
           .withColumn("cumsum_x", F.sum("x").over(w1.rangeBetween(Window.unboundedPreceding, 0)))
           .withColumn("cumsum_group", (col("cumsum_x") / reset_value).cast("int"))
           .withColumn("cumsum_x_reset", F.sum("x").over(w2))
          )
    
    # reset BEFORE reset value is reached
    df3 = (df
           .withColumn("cumsum_x", F.sum("x").over(w1.rangeBetween(Window.unboundedPreceding, 0)))
           .withColumn("cumsum_group", (col("cumsum_x") / reset_value).cast("int"))
           .withColumn("cumsum_group", F.lag("cumsum_group").over(w1))
           .withColumn("cumsum_group", F.when(col("cumsum_group").isNull(), 0).otherwise(col("cumsum_group")))
           .withColumn("cumsum_x_reset", F.sum("x").over(w2))
          )

输出:

+---+---+---+
| id|  t|  x|
+---+---+---+
|  1|  1| 10|
|  1|  2| 12|
|  1|  3|  5|
|  1|  4| 10|
|  2|  1|  6|
|  2|  3|  7|
+---+---+---+

+---+---+---+--------+------------+--------------+
| id|  t|  x|cumsum_x|cumsum_group|cumsum_x_reset|
+---+---+---+--------+------------+--------------+
|  1|  1| 10|      10|           0|            10|
|  1|  2| 12|      22|           1|            12|
|  1|  3|  5|      27|           1|            17|
|  1|  4| 10|      37|           1|            27|
|  2|  1|  6|       6|           0|             6|
|  2|  3|  7|      13|           0|            13|
+---+---+---+--------+------------+--------------+

+---+---+---+--------+------------+--------------+
| id|  t|  x|cumsum_x|cumsum_group|cumsum_x_reset|
+---+---+---+--------+------------+--------------+
|  1|  1| 10|      10|           0|            10|
|  1|  2| 12|      22|           0|            22|
|  1|  3|  5|      27|           1|             5|
|  1|  4| 10|      37|           1|            15|
|  2|  1|  6|       6|           0|             6|
|  2|  3|  7|      13|           0|            13|
+---+---+---+--------+------------+--------------+