下面是我拥有的数据框
df = sqlContext.createDataFrame(
[("0", "0"), ("1", "2"), ("2", "3"), ("3", "4"), ("4", "0"), ("5", "5"), ("6", "5")],
["id", "value"])
+---+-----+
| id|value|
+---+-----+
| 0| 0|
| 1| 2|
| 2| 3|
| 3| 4|
| 4| 0|
| 5| 5|
| 6| 5|
+---+-----+
我想得到的是:
+---+-----+---+-----+
| id|value|masterid|partsum|
+---+-----|---+-----+
| 0| 0| 0| 0|
| 1| 2| 0| 2|
| 2| 3| 0| 5|
| 3| 4| 0| 9|
| 4| 0| 4| 0|
| 5| 5| 4| 5|
| 6| 5| 4| 10|
+---+-----+---+-----+
所以我尝试使用SparkSQL这样做:
df=df.withColumn("masterid", F.when( df.value !=0 , F.lag(df.id)).otherwise(df.id))
我本来以为lag函数可以帮助我在下一次迭代之前进行处理,从而获得masterid col。不幸的是,在我查看了手册之后,它还是无济于事。
那么,我想问一下是否有任何特殊功能可以用来做我想做的事情?还是我可以使用任何“条件滞后”功能?这样,当我看到非零项目时,可以使用滞后直到找到零数字?
答案 0 :(得分:1)
IIUC,您可以尝试定义一个子组标签(以下代码中的g
)和两个窗口规范:
from pyspark.sql import Window, functions as F
w1 = Window.orderBy('id')
w2 = Window.partitionBy('g').orderBy('id')
df.withColumn('g', F.sum(F.expr('if(value=0,1,0)')).over(w1)).select(
'id'
, 'value'
, F.first('id').over(w2).alias('masterid')
, F.sum('value').over(w2).alias('partsum')
).show()
#+---+-----+--------+-------+
#| id|value|masterid|partsum|
#+---+-----+--------+-------+
#| 0| 0| 0| 0.0|
#| 1| 2| 0| 2.0|
#| 2| 3| 0| 5.0|
#| 3| 4| 0| 9.0|
#| 4| 0| 4| 0.0|
#| 5| 5| 4| 5.0|
#| 6| 5| 4| 10.0|
#+---+-----+--------+-------+