我有一个包含列的数据集:id,timestamp,x,y
[id],[timestamp],[x],[y]
0,1443489380,100,1
0,1443489390,200,0
0,1443489400,300,0
0,1443489410,400,1
我定义了一个窗口规范:w = Window.partitionBy("id").orderBy("timestamp")
我想做这样的事情。创建一个新列,将当前行的x与下一行的x相加。
如果sum> = 500,则设置new column = BIG else SMALL。
df = df.withColumn("newCol",
when(df.x + lag(df.x,-1).over(w) >= 500 , "BIG")
.otherwise("SMALL") )
但是,我想在执行此之前过滤数据而不影响原始df 。
[只有y = 1的行才能应用上述代码]
因此,将在代码上方应用的数据仅为这两行。
0,1443489380,100,1
0,1443489410,400,1
我已经这样做但是太糟糕了。
df2 = df.filter(df.y == 1)
df2 = df2.withColumn("newCol",
when(df.x + lag(df.x,-1).over(w) >= 500 , "BIG")
.otherwise("SMALL") )
df = df.join(df2, ["id","timestamp"], "outer")
我想做这样的事情,但它不可能,因为它会导致AttributeError:' DataFrame'对象没有属性'当'
df = df.withColumn("newCol", df.filter(df.y == 1)
.when(df.x + lag(df.x,-1).over(w) >= 500 , "BIG")
.otherwise("SMALL") )
总之,我只想对sum x和next x之前的y = 1的行进行临时过滤。
答案 0 :(得分:5)
你的代码工作正常,我认为你导入功能模块。试过你的代码,
>>> from pyspark.sql import functions as F
>>> df2 = df2.withColumn("newCol",
F.when((df.x + F.lag(df.x,-1).over(w))>= 500 , "BIG")
.otherwise("SMALL") )
>>> df2.show()
+---+----------+---+---+------+
| id| timestamp| x| y|newCol|
+---+----------+---+---+------+
| 0|1443489380|100| 1| BIG|
| 0|1443489410|400| 1| SMALL|
+---+----------+---+---+------+
编辑: 尝试过根据'id','y'列更改窗口分区,
>>> w = Window.partitionBy("id","y").orderBy("timestamp")
>>> df.select("*", F.when(df.y == 1,F.when((df.x+F.lag("x",-1).over(w)) >=500,'BIG').otherwise('SMALL')).otherwise(None).alias('new_col')).show()
+---+----------+---+---+-------+
| id| timestamp| x| y|new_col|
+---+----------+---+---+-------+
| 0|1443489380|100| 1| BIG|
| 0|1443489410|400| 1| SMALL|
| 0|1443489390|200| 0| null|
| 0|1443489400|300| 0| null|
+---+----------+---+---+-------+