我有一个Spark(2.4.0)数据框,其中的列只有两个值(0
或1
)。我需要计算此数据中连续0
和1
的条纹,如果值更改,则将条纹重置为零。
一个例子:
from pyspark.sql import (SparkSession, Window)
from pyspark.sql.functions import (to_date, row_number, lead, col)
spark = SparkSession.builder.appName('test').getOrCreate()
# Create dataframe
df = spark.createDataFrame([
('2018-01-01', 'John', 0, 0),
('2018-01-01', 'Paul', 1, 0),
('2018-01-08', 'Paul', 3, 1),
('2018-01-08', 'Pete', 4, 0),
('2018-01-08', 'John', 3, 0),
('2018-01-15', 'Mary', 6, 0),
('2018-01-15', 'Pete', 6, 0),
('2018-01-15', 'John', 6, 1),
('2018-01-15', 'Paul', 6, 1),
], ['str_date', 'name', 'value', 'flag'])
df.orderBy('name', 'str_date').show()
## +----------+----+-----+----+
## | str_date|name|value|flag|
## +----------+----+-----+----+
## |2018-01-01|John| 0| 0|
## |2018-01-08|John| 3| 0|
## |2018-01-15|John| 6| 1|
## |2018-01-15|Mary| 6| 0|
## |2018-01-01|Paul| 1| 0|
## |2018-01-08|Paul| 3| 1|
## |2018-01-15|Paul| 6| 1|
## |2018-01-08|Pete| 4| 0|
## |2018-01-15|Pete| 6| 0|
## +----------+----+-----+----+
使用此数据,我想计算连续的零和一的连胜纪录,按日期排序并按名称“显示在窗口中”:
# Expected result:
## +----------+----+-----+----+--------+--------+
## | str_date|name|value|flag|streak_0|streak_1|
## +----------+----+-----+----+--------+--------+
## |2018-01-01|John| 0| 0| 1| 0|
## |2018-01-08|John| 3| 0| 2| 0|
## |2018-01-15|John| 6| 1| 0| 1|
## |2018-01-15|Mary| 6| 0| 1| 0|
## |2018-01-01|Paul| 1| 0| 1| 0|
## |2018-01-08|Paul| 3| 1| 0| 1|
## |2018-01-15|Paul| 6| 1| 0| 2|
## |2018-01-08|Pete| 4| 0| 1| 0|
## |2018-01-15|Pete| 6| 0| 2| 0|
## +----------+----+-----+----+--------+--------+
当然,如果'flag'改变,我需要连胜将自己重置为零。
有没有办法做到这一点?
答案 0 :(得分:1)
这将要求先对具有相同值的第一组连续行使用行号方法,然后在各组之间使用排名方法。
from pyspark.sql import Window
from pyspark.sql import functions as f
#Windows definition
w1 = Window.partitionBy(df.name).orderBy(df.date)
w2 = Window.partitionBy(df.name,df.flag).orderBy(df.date)
res = df.withColumn('grp',f.row_number().over(w1)-f.row_number().over(w2))
#Window definition for streak
w3 = Window.partitionBy(res.name,res.flag,res.grp).orderBy(res.date)
streak_res = res.withColumn('streak_0',f.when(res.flag == 1,0).otherwise(f.row_number().over(w3))) \
.withColumn('streak_1',f.when(res.flag == 0,0).otherwise(f.row_number().over(w3)))
streak_res.show()