I have a DataFrame A, like:
+---+---+---+---+----------+
|key| c1| c2| c3|      date|
+---+---+---+---+----------+
| k1| -1|  0| -1|2015-04-28|
| k1|  1| -1|  1|2015-07-28|
| k1|  1|  1|  1|2015-10-28|
| k1|  1|  1| -1|2015-12-28|
| k2| -1|  0|  1|2015-04-28|
| k2| -1|  1| -1|2015-07-28|
| k2|  1| -1|  0|2015-10-28|
| k2|  1| -1|  1|2015-11-28|
+---+---+---+---+----------+
The code to create A:
data = [('k1', '-1', '0', '-1', '2015-04-28'),
        ('k1', '1', '-1', '1', '2015-07-28'),
        ('k1', '1', '1', '1', '2015-10-28'),
        ('k1', '1', '1', '-1', '2015-12-28'),
        ('k2', '-1', '0', '1', '2015-04-28'),
        ('k2', '-1', '1', '-1', '2015-07-28'),
        ('k2', '1', '-1', '0', '2015-10-28'),
        ('k2', '1', '-1', '1', '2015-11-28')]

A = spark.createDataFrame(data, ['key', 'c1', 'c2', 'c3', 'date'])
A = A.withColumn('date', A.date.cast('date'))
I want to get, for each key, the date on which the value of column c3 changes for the first time (with rows ordered by date ascending). The expected result:
+---+---+----------+
|key| c3|      date|
+---+---+----------+
| k1|  1|2015-07-28|
| k2| -1|2015-07-28|
+---+---+----------+
Answer 0 (score: 2)
This is clearly a job for window functions:
from pyspark.sql.window import Window
from pyspark.sql.functions import col, lag, sum

# Define a window partitioned by key and ordered by date
w = Window.partitionBy("key").orderBy("date")

(A
    .withColumn(
        "change",
        # yield 1 if c3 differs from the previous row's value, 0 otherwise
        (lag("c3", 1).over(w) != col("c3")).cast("int"))
    .withColumn(
        "ind",
        # cumulative number of changes seen so far within each key
        sum("change").over(w))
    # c3 has changed for the first time on the row where a change occurs
    # and the cumulative sum equals 1 (window functions cannot be used
    # directly in where, hence the intermediate columns)
    .where((col("change") == 1) & (col("ind") == 1))
    .drop("change"))
Result:
+---+---+---+---+----------+---+
|key| c1| c2| c3|      date|ind|
+---+---+---+---+----------+---+
| k2| -1|  1| -1|2015-07-28|  1|
| k1|  1| -1|  1|2015-07-28|  1|
+---+---+---+---+----------+---+
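To reproduce the expected output from the question exactly, assign the pipeline above to a variable and project just the three columns. A minimal sketch, assuming the result has been assigned to a name first_change (my own name, not from the answer):

first_change.select("key", "c3", "date").show()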
Answer 1 (score: 1)
Here is my solution using a UDF.
import pyspark.sql.functions as func
from pyspark.sql.types import ArrayType, StringType

data = [('k1', '-1', '0', '-1', '2015-04-28'),
        ('k1', '1', '-1', '1', '2015-07-28'),
        ('k1', '1', '1', '1', '2015-10-28'),
        ('k2', '-1', '0', '1', '2015-04-28'),
        ('k2', '-1', '1', '-1', '2015-07-28'),
        ('k2', '1', '-1', '0', '2015-10-28')]

# note that I didn't cast the date column to date type here
A = spark.createDataFrame(data, ['key', 'c1', 'c2', 'c3', 'date'])

# collect c3 and date into lists per key
# (caveat: collect_list does not guarantee row order; see the note after the output)
A_group = A.select('key', 'c3', 'date').groupby('key')
A_agg = A_group.agg(func.collect_list(func.col('c3')).alias('c3'),
                    func.collect_list(func.col('date')).alias('date_list'))

# UDF to return the first change for a given list
def find_change(c3_list, date_list):
    """Return [new value, date] of the first change, or None if there is none."""
    for i in range(1, len(c3_list)):
        if c3_list[i] != c3_list[i - 1]:
            return [c3_list[i], date_list[i]]
    return None

udf_find_change = func.udf(find_change, returnType=ArrayType(StringType()))

# find the first change for each key
A_first_change = A_agg.select(
    'key',
    udf_find_change(func.col('c3'), func.col('date_list')).alias('first_change'))

A_first_change.select(
    'key',
    func.col('first_change').getItem(0).alias('c3'),
    func.col('first_change').getItem(1).cast('date').alias('date')).show()
Output:
+---+---+----------+
|key| c3|      date|
+---+---+----------+
| k2| -1|2015-07-28|
| k1|  1|2015-07-28|
+---+---+----------+
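A caveat on the UDF approach above: collect_list does not guarantee that values are collected in date order, so the result is not deterministic in general. A minimal sketch of one way to make the ordering explicit, collecting (date, c3) structs and sorting them with sort_array (the column names pairs and date_list are my own choices, not from the original answer):

import pyspark.sql.functions as func

# Collect (date, c3) pairs per key and sort them; sort_array orders structs
# by their first field, and ISO-formatted date strings sort chronologically
A_agg = (A.select('key', 'c3', 'date')
          .groupby('key')
          .agg(func.sort_array(func.collect_list(func.struct('date', 'c3')))
               .alias('pairs'))
          .select('key',
                  func.col('pairs.c3').alias('c3'),
                  func.col('pairs.date').alias('date_list')))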