Add a column with the sum of all previous rows within the same group

Asked: 2019-07-25 19:35:49

Tags: scala apache-spark pyspark apache-spark-sql

I need to create a "rolling count" column that takes the previous count and adds the new count for each day and company. I have already organized and sorted the dataframe into ascending-date groups for each company with their corresponding counts. I also added an "ix" column that indexes each grouping (one way such an index can be built is sketched after the sample below), like so:

+--------------------+--------------------+-----+---+
|     Normalized_Date|             company|count| ix|
+--------------------+--------------------+-----+---+
|09/25/2018 00:00:...|[5c40c8510fb7c017...|    7|  1|
|09/25/2018 00:00:...|[5bdb2b543951bf07...|    9|  1|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...|    7|  1|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...|   60|  2|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...|    1|  3|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...|    9|  4|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...|   29|  5|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...|   42|  6|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...|  317|  7|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...|    3|  8|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...|   15|  9|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...|    1| 10|
+--------------------+--------------------+-----+---+
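
For reference, one way such an index can be built (a sketch only, assuming the dataframe is named solutionDF and uses the same date format as the code further below) is row_number over a per-company window ordered by date:

import pyspark.sql.functions as F
from pyspark.sql import Window

# sketch: number each company's rows in ascending date order
ix_window = Window.partitionBy('company').orderBy(
    F.unix_timestamp('Normalized_Date', 'MM/dd/yyyy HH:mm:ss aaa').cast('timestamp'))
solutionDF = solutionDF.withColumn('ix', F.row_number().over(ix_window))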

The output I need would simply add up all of the counts for each company up to that date, like this:

+--------------------+--------------------+-----+---+------------+
|     Normalized_Date|             company|count| ix|RollingCount|
+--------------------+--------------------+-----+---+------------+
|09/25/2018 00:00:...|[5c40c8510fb7c017...|    7|  1|           7|
|09/25/2018 00:00:...|[5bdb2b543951bf07...|    9|  1|           9|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...|    7|  1|           7|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...|   60|  2|          67|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...|    1|  3|          68|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...|    9|  4|          77|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...|   29|  5|         106|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...|   42|  6|         148|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...|  317|  7|         465|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...|    3|  8|         468|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...|   15|  9|         483|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...|    1| 10|         484|
+--------------------+--------------------+-----+---+------------+

I figured the lag function would be useful, and with the following code I was able to get each RollingCount row with ix > 1 to add the count from the row directly above it:

import pyspark.sql.functions as F
from pyspark.sql import Window

# window per company, ordered chronologically by the parsed date
w = Window.partitionBy('company').orderBy(F.unix_timestamp('Normalized_Date', 'MM/dd/yyyy HH:mm:ss aaa').cast('timestamp'))
refined_DF = solutionDF.withColumn("rn", F.row_number().over(w))
# only adds the single previous count, not the whole running total
solutionDF = refined_DF.withColumn('RollingCount', F.when(refined_DF['rn'] > 1, refined_DF['count'] + F.lag(refined_DF['count'], 1).over(w)).otherwise(refined_DF['count']))

which produces the following df:

+--------------------+--------------------+-----+---+------------+
|     Normalized_Date|             company|count| ix|RollingCount|
+--------------------+--------------------+-----+---+------------+
|09/25/2018 00:00:...|[5c40c8510fb7c017...|    7|  1|           7|
|09/25/2018 00:00:...|[5bdb2b543951bf07...|    9|  1|           9|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...|    7|  1|           7|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...|   60|  2|          67|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...|    1|  3|          61|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...|    9|  4|          10|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...|   29|  5|          38|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...|   42|  6|          71|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...|  317|  7|         359|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...|    3|  8|         320|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...|   15|  9|          18|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...|    1| 10|          16|
+--------------------+--------------------+-----+---+------------+

I just need it to sum up all of the counts in the rows above each ix row. I tried using a udf to work out the count input to the lag function, but I kept getting a "'Column' object is not callable" error, and it would not sum across all the rows anyway. I also tried using a loop, but that seems impractical because it would create a new dataframe on every iteration, and I would then have to join them all back together. There must be an easier way to do this. Perhaps with a different function than lag?

2 answers:

Answer 0 (score: 1)

lag returns a single row some fixed distance before your current one, but you need a whole range of rows to compute a cumulative sum. You therefore have to give the window a frame with rangeBetween (or rowsBetween). Take a look at the example below:

import pyspark.sql.functions as F
from pyspark.sql import Window

l = [
('09/25/2018', '5c40c8510fb7c017',  7,  1),
('09/25/2018', '5bdb2b543951bf07',    9,  1),
('11/28/2017', '593b0d9f3f21f9dd',     7,  1),
('11/29/2017', '593b0d9f3f21f9dd',    60,  2),
('01/09/2018', '593b0d9f3f21f9dd',     1,  3),
('04/27/2018', '593b0d9f3f21f9dd',     9,  4),
('09/25/2018', '593b0d9f3f21f9dd',    29,  5),
('11/20/2018', '593b0d9f3f21f9dd',    42,  6),
('12/11/2018', '593b0d9f3f21f9dd',   317,  7),
('01/04/2019', '593b0d9f3f21f9dd',     3,  8),
('02/13/2019', '593b0d9f3f21f9dd',    15,  9),
('04/01/2019', '593b0d9f3f21f9dd',     1, 10)
]

columns = ['Normalized_Date', 'company','count', 'ix']

df=spark.createDataFrame(l, columns)

# parse the string dates so the window can order them chronologically
df = df.withColumn('Normalized_Date', F.to_date(df.Normalized_Date, 'MM/dd/yyyy'))

# frame from the start of each company's partition up to (and including) the current row
w = Window.partitionBy('company').orderBy('Normalized_Date').rangeBetween(Window.unboundedPreceding, 0)

df = df.withColumn('Rolling_count', F.sum('count').over(w))
df.show()

Output:

+---------------+----------------+-----+---+-------------+ 
|Normalized_Date|         company|count| ix|Rolling_count| 
+---------------+----------------+-----+---+-------------+ 
|     2018-09-25|5c40c8510fb7c017|    7|  1|            7| 
|     2018-09-25|5bdb2b543951bf07|    9|  1|            9| 
|     2017-11-28|593b0d9f3f21f9dd|    7|  1|            7| 
|     2017-11-29|593b0d9f3f21f9dd|   60|  2|           67| 
|     2018-01-09|593b0d9f3f21f9dd|    1|  3|           68| 
|     2018-04-27|593b0d9f3f21f9dd|    9|  4|           77| 
|     2018-09-25|593b0d9f3f21f9dd|   29|  5|          106| 
|     2018-11-20|593b0d9f3f21f9dd|   42|  6|          148| 
|     2018-12-11|593b0d9f3f21f9dd|  317|  7|          465| 
|     2019-01-04|593b0d9f3f21f9dd|    3|  8|          468| 
|     2019-02-13|593b0d9f3f21f9dd|   15|  9|          483| 
|     2019-04-01|593b0d9f3f21f9dd|    1| 10|          484| 
+---------------+----------------+-----+---+-------------+
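
As mentioned above, rowsBetween works as well. A minimal sketch of the equivalent row-based frame, reusing the df built in the snippet above (the name w_rows is just for illustration): rangeBetween(..., 0) pulls in every row that shares the current order value, while rowsBetween stops at the current row, which only matters if a company has several rows on the same date.

import pyspark.sql.functions as F
from pyspark.sql import Window

# df is the dataframe built in the snippet above
# row-based frame: from the partition start up to and including the current row
w_rows = Window.partitionBy('company').orderBy('Normalized_Date') \
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)

df.withColumn('Rolling_count', F.sum('count').over(w_rows)).show()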

Answer 1 (score: 0)

Try this. You need the sum of all the preceding rows up to the current row in the window frame.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.expressions.WindowSpec
import org.apache.spark.sql.functions._

val df = Seq(
("5c40c8510fb7c017", 7, 1),
("5bdb2b543951bf07", 9, 1),
("593b0d9f3f21f9dd", 7, 1),
("593b0d9f3f21f9dd", 60, 2),
("593b0d9f3f21f9dd", 1, 3),
("593b0d9f3f21f9dd", 9, 4),
("593b0d9f3f21f9dd", 29, 5),
("593b0d9f3f21f9dd", 42, 6),
("593b0d9f3f21f9dd", 317, 7),
("593b0d9f3f21f9dd", 3, 8),
("593b0d9f3f21f9dd", 15, 9),
("593b0d9f3f21f9dd", 1, 10)
).toDF("company", "count", "ix")

scala> df.show(false)
+----------------+-----+---+
|company         |count|ix |
+----------------+-----+---+
|5c40c8510fb7c017|7    |1  |
|5bdb2b543951bf07|9    |1  |
|593b0d9f3f21f9dd|7    |1  |
|593b0d9f3f21f9dd|60   |2  |
|593b0d9f3f21f9dd|1    |3  |
|593b0d9f3f21f9dd|9    |4  |
|593b0d9f3f21f9dd|29   |5  |
|593b0d9f3f21f9dd|42   |6  |
|593b0d9f3f21f9dd|317  |7  |
|593b0d9f3f21f9dd|3    |8  |
|593b0d9f3f21f9dd|15   |9  |
|593b0d9f3f21f9dd|1    |10 |
+----------------+-----+---+


scala> val overColumns = Window.partitionBy("company").orderBy("ix").rowsBetween(Window.unboundedPreceding, Window.currentRow)
overColumns: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@3ed5e17c

scala> val outputDF = df.withColumn("RollingCount", sum("count").over(overColumns))
outputDF: org.apache.spark.sql.DataFrame = [company: string, count: int ... 2 more fields]

scala> outputDF.show(false)
+----------------+-----+---+------------+
|company         |count|ix |RollingCount|
+----------------+-----+---+------------+
|5c40c8510fb7c017|7    |1  |7           |
|5bdb2b543951bf07|9    |1  |9           |
|593b0d9f3f21f9dd|7    |1  |7           |
|593b0d9f3f21f9dd|60   |2  |67          |
|593b0d9f3f21f9dd|1    |3  |68          |
|593b0d9f3f21f9dd|9    |4  |77          |
|593b0d9f3f21f9dd|29   |5  |106         |
|593b0d9f3f21f9dd|42   |6  |148         |
|593b0d9f3f21f9dd|317  |7  |465         |
|593b0d9f3f21f9dd|3    |8  |468         |
|593b0d9f3f21f9dd|15   |9  |483         |
|593b0d9f3f21f9dd|1    |10 |484         |
+----------------+-----+---+------------+