pyspark: how to make collect_list() store, on each row, the list of elements remaining in the group

Date: 2018-08-28 11:56:07

Tags: pyspark

My dataset is grouped by two variables, "customer" and "sku", and collect_list() puts the whole group's list on every row. Instead, I want each row to store only the elements still remaining in the group from that row onward.

I currently get this output:

+----------+--------------------+-----------+--------------------+---+
|  customer|                 sku|auto_create|        next_creates|  n|
+----------+--------------------+-----------+--------------------+---+
|248274_ARC|J31/H01N2-D35MM2-...|          Y|           [Y, Y, Y]|  3|
|248274_ARC|J31/H01N2-D35MM2-...|          Y|           [Y, Y, Y]|  3|
|248274_ARC|J31/H01N2-D35MM2-...|          Y|           [Y, Y, Y]|  3|
|297945_ARC|  F87/012V55WH31EX10|          Y|        [Y, Y, Y, Y]|  4|
|297945_ARC|  F87/012V55WH31EX10|          Y|        [Y, Y, Y, Y]|  4|
|297945_ARC|  F87/012V55WH31EX10|          Y|        [Y, Y, Y, Y]|  4|
|297945_ARC|  F87/012V55WH31EX10|          Y|        [Y, Y, Y, Y]|  4|
|318725_ARC|          605/85524V|          N|           [N, N, N]|  3|
|318725_ARC|          605/85524V|          N|           [N, N, N]|  3|
|318725_ARC|          605/85524V|          N|           [N, N, N]|  3|
|403787_ARC|     BPC/77/9601-136|          N|  [N, N, N, Y, Y, Y]|  6|
|403787_ARC|     BPC/77/9601-136|          N|  [N, N, N, Y, Y, Y]|  6|
|403787_ARC|     BPC/77/9601-136|          N|  [N, N, N, Y, Y, Y]|  6|
|403787_ARC|     BPC/77/9601-136|          Y|  [N, N, N, Y, Y, Y]|  6|
|403787_ARC|     BPC/77/9601-136|          Y|  [N, N, N, Y, Y, Y]|  6|
|403787_ARC|     BPC/77/9601-136|          Y|  [N, N, N, Y, Y, Y]|  6|
|434238_ARC|        BB8/40300142|          Y|        [Y, Y, Y, Y]|  4|
|434238_ARC|        BB8/40300142|          Y|        [Y, Y, Y, Y]|  4|
|434238_ARC|        BB8/40300142|          Y|        [Y, Y, Y, Y]|  4|
|434238_ARC|        BB8/40300142|          Y|        [Y, Y, Y, Y]|  4|
+----------+--------------------+-----------+--------------------+---+

I want this output instead:

+----------+--------------------+-----------+--------------------+---+
|  customer|                 sku|auto_create|        next_creates|  n|
+----------+--------------------+-----------+--------------------+---+
|248274_ARC|J31/H01N2-D35MM2-...|          Y|           [Y, Y, Y]|  3|
|248274_ARC|J31/H01N2-D35MM2-...|          Y|              [Y, Y]|  3|
|248274_ARC|J31/H01N2-D35MM2-...|          Y|                 [Y]|  3|
|297945_ARC|  F87/012V55WH31EX10|          Y|        [Y, Y, Y, Y]|  4|
|297945_ARC|  F87/012V55WH31EX10|          Y|           [Y, Y, Y]|  4|
|297945_ARC|  F87/012V55WH31EX10|          Y|              [Y, Y]|  4|
|297945_ARC|  F87/012V55WH31EX10|          Y|                 [Y]|  4|
|318725_ARC|          605/85524V|          N|           [N, N, N]|  3|
|318725_ARC|          605/85524V|          N|              [N, N]|  3|
|318725_ARC|          605/85524V|          N|                 [N]|  3|
|403787_ARC|     BPC/77/9601-136|          N|  [N, N, N, Y, Y, Y]|  6|
|403787_ARC|     BPC/77/9601-136|          N|     [N, N, Y, Y, Y]|  6|
|403787_ARC|     BPC/77/9601-136|          N|        [N, Y, Y, Y]|  6|
|403787_ARC|     BPC/77/9601-136|          Y|           [Y, Y, Y]|  6|
|403787_ARC|     BPC/77/9601-136|          Y|              [Y, Y]|  6|
|403787_ARC|     BPC/77/9601-136|          Y|                 [Y]|  6|
|434238_ARC|        BB8/40300142|          Y|        [Y, Y, Y, Y]|  4|
|434238_ARC|        BB8/40300142|          Y|           [Y, Y, Y]|  4|
|434238_ARC|        BB8/40300142|          Y|              [Y, Y]|  4|
|434238_ARC|        BB8/40300142|          Y|                 [Y]|  4|
+----------+--------------------+-----------+--------------------+---+

I am using the following code:

from pyspark.sql import Window
from pyspark.sql.functions import collect_list

w = Window.partitionBy('customer', 'sku').orderBy('customer', 'sku')
analysis = analysis.withColumn('next_creates', collect_list('auto_create').over(w))

My attempt at the join suggested in one of the answers below, which did not give the right result:

analysis = analysis.withColumn('rownumber', row_number().over(w))

df1 = analysis
df2 = analysis

df1.join(df2, (df1.customer == df2.customer)
              & (df1.sku == df2.sku)
              & (df1.rownumber <= df2.rownumber)) \
   .groupBy('customer', 'sku').agg(collect_list('auto_create'))

3 Answers:

Answer 0 (score: 0)

If you can introduce a row-number column, you can run a query like the one below. Pseudocode, untested:

from pyspark.sql.functions import col, collect_list

(df.alias('df1')
   .join(df.alias('df2'),
         on=(col('df1.customer') == col('df2.customer'))
            & (col('df1.sku') == col('df2.sku'))
            & (col('df1.rownum') <= col('df2.rownum')))
   .groupBy('df1.customer', 'df1.sku', 'df1.rownum', 'df1.auto_create')
   .agg(collect_list('df2.auto_create')))
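As a sanity check, the join-and-collect logic can be sketched in plain Python on one hypothetical (customer, sku) group with a 1-based rownum already assigned:

```python
# Hypothetical rows for a single (customer, sku) group:
# (customer, sku, rownum, auto_create)
rows = [
    ("403787_ARC", "BPC/77/9601-136", 1, "N"),
    ("403787_ARC", "BPC/77/9601-136", 2, "N"),
    ("403787_ARC", "BPC/77/9601-136", 3, "Y"),
]

# For each "left" row, collect auto_create from every "right" row in the
# same group whose rownum is >= the left row's rownum -- the plain-Python
# analogue of the df1/df2 join condition above.
collected = {
    rn1: [ac2 for c2, s2, rn2, ac2 in rows
          if (c2, s2) == (c1, s1) and rn1 <= rn2]
    for c1, s1, rn1, _ in rows
}
```

Row 1 keeps the whole group, row 3 keeps only its own value, which matches the desired shrinking lists.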

Answer 1 (score: 0)

My solution:

I added a "rownumber" column:

df = df.withColumn('rownumber', row_number().over(w))

Then I truncated "next_creates" row by row:

from pyspark.sql.functions import expr

# slice() is 1-based and available from Spark 2.4
df = df.withColumn('next_creates',
                   expr('slice(next_creates, rownumber, size(next_creates))'))
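For intuition, the per-row truncation amounts to dropping the first rownumber - 1 elements of the collected list. A minimal plain-Python sketch with hypothetical data:

```python
# One group's collected list, truncated for each 1-based row number:
next_creates = ["N", "N", "N", "Y", "Y", "Y"]
suffixes = [next_creates[rn - 1:] for rn in range(1, len(next_creates) + 1)]
```

Row 1 sees the full list, each later row sees one element fewer, and the last row sees only its own value.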

Answer 2 (score: 0)

If you are interested in a more Spark-like solution: you can use rowsBetween on the Window and end up with the following:


from pyspark.sql import Window
from pyspark.sql.functions import collect_list

# create a test dataframe
test_df = spark.createDataFrame([
    ("318725_ARC", "605/85524V", "N"), ("318725_ARC", "605/85524V", "N"),
    ("318725_ARC", "605/85524V", "N"), ("403787_ARC", "BPC/77/9601-136", "N"),
    ("403787_ARC", "BPC/77/9601-136", "N"), ("403787_ARC", "BPC/77/9601-136", "N"),
    ("403787_ARC", "BPC/77/9601-136", "Y"), ("403787_ARC", "BPC/77/9601-136", "Y"),
    ("403787_ARC", "BPC/77/9601-136", "Y")], ("customer", "sku", "auto_create"))

w = Window.partitionBy('customer', 'sku') \
          .orderBy('customer', 'sku') \
          .rowsBetween(0, Window.unboundedFollowing)
analysis = test_df.withColumn('next_creates', collect_list('auto_create').over(w))

analysis.show()
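To see what the rowsBetween(0, Window.unboundedFollowing) frame computes without running a Spark session, here is a plain-Python sketch over the same hypothetical rows as test_df (note that ordering only by the partition keys leaves the order of ties up to Spark):

```python
from itertools import groupby

# Same hypothetical rows as test_df: (customer, sku, auto_create)
rows = [
    ("318725_ARC", "605/85524V", "N"), ("318725_ARC", "605/85524V", "N"),
    ("318725_ARC", "605/85524V", "N"), ("403787_ARC", "BPC/77/9601-136", "N"),
    ("403787_ARC", "BPC/77/9601-136", "N"), ("403787_ARC", "BPC/77/9601-136", "N"),
    ("403787_ARC", "BPC/77/9601-136", "Y"), ("403787_ARC", "BPC/77/9601-136", "Y"),
    ("403787_ARC", "BPC/77/9601-136", "Y"),
]

# Within each (customer, sku) group, each row's frame runs from the current
# row to the end of the group, so collect_list over it yields the suffix
# of auto_create values starting at that row.
expected = []
for key, grp in groupby(rows, key=lambda r: (r[0], r[1])):
    creates = [ac for _, _, ac in grp]
    for i in range(len(creates)):
        expected.append((*key, creates[i], creates[i:]))
```

Each tuple in `expected` mirrors one row of `analysis`: the group keys, that row's auto_create, and its shrinking next_creates list.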