说我有一个包含杂志订阅信息的数据框:
subscription_id user_id created_at expiration_date
12384 1 2018-08-10 2018-12-10
83294 1 2018-06-03 2018-10-03
98234 1 2018-04-08 2018-08-08
24903 2 2018-05-08 2018-07-08
32843 2 2018-03-25 2018-05-25
09283 2 2018-04-07 2018-06-07
现在,我想添加一列,指出用户在此当前订阅开始之前已过期的先前订阅数量。换句话说,与给定用户关联的到期日期早于该订阅的开始日期。这是完整的期望输出:
subscription_id user_id created_at expiration_date previous_expired
12384 1 2018-08-10 2018-12-10 1
83294 1 2018-06-03 2018-10-03 0
98234 1 2018-04-08 2018-08-08 0
24903 2 2018-05-08 2018-07-08 2
32843 2 2018-03-25 2018-05-03 1
09283 2 2018-01-25 2018-02-25 0
尝试:
编辑:使用Python尝试了各种延迟/超前/等,我现在认为这是一个SQL问题
df = df.withColumn('shiftlag', func.lag(df.expires_at).over(Window.partitionBy('user_id').orderBy('created_at')))
<---编辑,编辑:没关系,这不起作用
我想我用尽了滞后/超前/移位方法,却发现它不起作用。我现在认为最好使用Spark SQL进行此操作,也许用case when
来生成新列,再结合having
count
(按ID分组)?>
答案 0 :(得分:0)
使用PySpark弄清楚了:
我首先创建了另一列,其中包含每个用户的所有到期日期的数组:
joined_array = df.groupBy('user_id').agg(collect_set('expiration_date'))
然后将该数组重新加入原始数据框:
joined_array = joined_array.toDF('user_idDROP', 'expiration_date_array')
df = df.join(joined_array, df.user_id == joined_array.user_idDROP, how = 'left').drop('user_idDROP')
然后创建了一个遍历数组的函数,如果创建的日期大于到期日期,则将1加到计数中:
def check_expiration_count(created_at, expiration_array):
if not expiration_array:
return 0
else:
count = 0
for i in expiration_array:
if created_at > i:
count += 1
return count
check_expiration_count = udf(check_expiration_count, IntegerType())
然后应用该函数创建一个具有正确计数的新列:
df = df.withColumn('count_of_subs_ending_before_creation', check_expiration_count(df.created_at, df.expiration_array))
Wala。做完了谢谢大家(没人帮忙,还是要谢谢)。希望有人会在2022年发现它有用