Spark SQL分区依据,窗口,顺序依据,计数

时间:2018-12-11 18:16:28

标签: python mysql sql sql-server python-2.7

说我有一个包含杂志订阅信息的数据框:

subscription_id    user_id       created_at       expiration_date
 12384               1           2018-08-10        2018-12-10
 83294               1           2018-06-03        2018-10-03
 98234               1           2018-04-08        2018-08-08
 24903               2           2018-05-08        2018-07-08
 32843               2           2018-03-25        2018-05-25
 09283               2           2018-04-07        2018-06-07

现在,我想添加一列,指出用户在此当前订阅开始之前已过期的先前订阅数量。换句话说,与给定用户关联的到期日期早于该订阅的开始日期。这是完整的期望输出:

subscription_id    user_id       created_at       expiration_date   previous_expired
 12384               1           2018-08-10        2018-12-10          1
 83294               1           2018-06-03        2018-10-03          0
 98234               1           2018-04-08        2018-08-08          0
 24903               2           2018-05-08        2018-07-08          2
 32843               2           2018-03-25        2018-05-03          1
 09283               2           2018-01-25        2018-02-25          0

尝试:

编辑:使用Python尝试了各种延迟/超前/等,我现在认为这是一个SQL问题

df = df.withColumn('shiftlag', func.lag(df.expires_at).over(Window.partitionBy('user_id').orderBy('created_at')))

<---编辑,编辑:没关系,这不起作用

我想我用尽了滞后/超前/移位方法,却发现它不起作用。我现在认为最好使用Spark SQL进行此操作,也许用case when来生成新列,再结合having count(按ID分组)?

1 个答案:

答案 0 :(得分:0)

使用PySpark弄清楚了:

我首先创建了另一列,其中包含每个用户的所有到期日期的数组:

joined_array = df.groupBy('user_id').agg(collect_set('expiration_date'))

然后将该数组重新加入原始数据框:

joined_array = joined_array.toDF('user_idDROP', 'expiration_date_array')
df = df.join(joined_array, df.user_id == joined_array.user_idDROP, how = 'left').drop('user_idDROP')

然后创建了一个遍历数组的函数,如果创建的日期大于到期日期,则将1加到计数中:

def check_expiration_count(created_at, expiration_array):
  if not expiration_array:
    return 0
  else:
   count = 0
    for i in expiration_array:
  if created_at > i:
    count += 1
return count

check_expiration_count = udf(check_expiration_count, IntegerType())

然后应用该函数创建一个具有正确计数的新列:

df = df.withColumn('count_of_subs_ending_before_creation', check_expiration_count(df.created_at, df.expiration_array))

Wala。做完了谢谢大家(没人帮忙,还是要谢谢)。希望有人会在2022年发现它有用