SparkSQL在分组后从数据帧之前和之后获取行

时间:2016-01-21 01:33:51

标签: sql scala apache-spark apache-spark-sql window-functions

鉴于此Dataframe df

 +-----------+--------------------+-------------+-------+
|CustNumb   |        PurchaseDate|     price| activeFlag|
+-----------+--------------------+-------------+-------+
|          3|2013-07-17 00:00:...|         17.9|    0|
|          3|2013-08-27 00:00:...|        61.13|    0|
|          3|2013-08-28 00:00:...|        25.07|    1|
|          3|2013-08-29 00:00:...|        24.23|    0|
|          3|2013-09-06 00:00:...|         3.94|    0|
|         20|2013-02-28 00:00:...|       354.64|    0|
|         20|2013-04-07 00:00:...|         15.0|    0|
|         20|2013-05-10 00:00:...|        545.0|    0|
|         28|2013-02-17 00:00:...|        190.0|    0|
|         28|2013-04-08 00:00:...|         20.0|    0|
|         28|2013-04-16 00:00:...|         89.0|    0|
|         28|2013-05-18 00:00:...|        260.0|    0|
|         28|2013-06-06 00:00:...|       586.57|    1|
|         28|2013-06-09 00:00:...|        250.0|    0|

我想得到的结果是,当它找到一个非活动标志时,返回按购买日期排序前后2行价格的平均值' 1'。这是我要找的结果:

+-----------+--------------------+-------------+-------+---------------+
|CustNumb   |        PurchaseDate|     price| activeFlag| OutputVal |
+-----------+--------------------+-------------+-------+------------+
|          3|2013-07-17 00:00:...|         17.9|    0|   17.9
|          3|2013-08-27 00:00:...|        61.13|    0|   61.13
|          3|2013-08-28 00:00:...|        25.07|    1|   26.8 (avg of 2 prices before and 2 after)
|          3|2013-08-29 00:00:...|        24.23|    0|   24.23
|          3|2013-09-06 00:00:...|         3.94|    0|   3.94

|         20|2013-02-28 00:00:...|       354.64|    0|   354.64
|         20|2013-04-07 00:00:...|         15.0|    0|   15.0
|         20|2013-05-10 00:00:...|        545.0|    0|   545.0

|         28|2013-02-17 00:00:...|        190.0|    0|   190.0
|         28|2013-04-08 00:00:...|         20.0|    0|   20.0
|         28|2013-04-16 00:00:...|         89.0|    0|   89.0
|         28|2013-05-18 00:00:...|        260.0|    0|   260.0
|         28|2013-06-06 00:00:...|       586.57|    1|   199.6 (avg of 2 prices before and 1 after)
|         28|2013-06-09 00:00:...|        250.0|    0|   250

在上面的例子中,对于custNum 3和28,我有activeFlag 1,所以如果它存在相同的custNumb,我需要计算前后2行的平均值。

我正在考虑在数据框架上使用窗口函数,但是无法获得任何好的想法来解决这个问题,因为我不太喜欢火花编程

val w = Window.partitionBy("CustNumb").orderBy("PurchaseDate")

我如何实现这一点,是否可以通过Window功能或任何更好的方法实现?

2 个答案:

答案 0 :(得分:0)

如果您已经有窗口,那么这样的简单条件应该可以正常工作:

val cond = ($"activeFlag" === 1) && (lag($"activeFlag", 1).over(w) === 0)

// Windows covering rows before and after
val before = w.rowsBetween(-2, -1)
val after = w.rowsBetween(1, 2)

// Expression with sum of rows and number of rows 
val sumPrice = sum($"price").over(before) + sum($"price").over(after)
val countPrice = sum($"ones_").over(before) + sum($"ones_").over(after)

val expr = when(cond, sumPrice / countPrice).otherwise($"price")

df.withColumn("ones_", lit(1)).withColumn("outputVal", expr)

答案 1 :(得分:0)

感谢Zero323。你摇滚! 这是我的代码片段,基于您的帮助我修改了以获取我在结果中寻找的数据:

 val windw = Window.partitionBy("CustNumb").orderBy("PurchaseDate")
 val cond = ($"activeFlag" === 1) //&& (lag($"activeFlag", 1).over(win) === 0)
 val avgprice = (lag($"price", 1).over(windw)  + lag($"price", 2).over(windw) + lead($"price", 1).over(windw)  + lead($"price", 2).over(windw)) / 4.0
 val expr = when(cond, avgprice).otherwise($"price")
 val finalresult = df.withColumn("newPrice", expr)

我唯一想知道的是,如果activeflag = 1存在于上面的行中,那么我想在activeflag = 1的行上方多行一行。如果我找到解决方法,我会尝试更新。