如何按多列分组并在PySpark中列表?

时间:2017-10-03 07:10:51

标签: apache-spark pyspark apache-spark-sql pyspark-sql

这是我的问题: 我有这个RDD:

a = [[u'PNR1',u'TKT1',u'TEST',u'a2',u'a3'],[u'PNR1',u'TKT1',u'TEST',u'a5',u'a6'],[u'PNR1',u'TKT1',u'TEST',u'a8',u'a9']]

rdd= sc.parallelize (a)

然后我尝试:

rdd.map(lambda x: (x[0],x[1],x[2], list(x[3:])))

.toDF(["col1","col2","col3","col4"])

.groupBy("col1","col2","col3")

.agg(collect_list("col4")).show

最后我应该找到这个:

[col1,col2,col3,col4]=[u'PNR1',u'TKT1',u'TEST',[[u'a2',u'a3'][u'a5',u'a6'][u'a8',u'a9']]]

但问题是我无法收集清单。

如果有人能帮助我,我会很感激

3 个答案:

答案 0 :(得分:1)

这可能会完成你的工作(或者给你一些进一步的想法)......

一个想法是将您的col4转换为原始数据类型,即字符串:

from pyspark.sql.functions import collect_list
import pandas as pd

a = [[u'PNR1',u'TKT1',u'TEST',u'a2',u'a3'],[u'PNR1',u'TKT1',u'TEST',u'a5',u'a6'],[u'PNR1',u'TKT1',u'TEST',u'a8',u'a9']]
rdd = sc.parallelize(a)

df = rdd.map(lambda x: (x[0],x[1],x[2], '(' + ' '.join(str(e) for e in x[3:]) + ')')).toDF(["col1","col2","col3","col4"])

df.groupBy("col1","col2","col3").agg(collect_list("col4")).toPandas().values.tolist()[0]
#[u'PNR1', u'TKT1', u'TEST', [u'(a2 a3)', u'(a5 a6)', u'(a8 a9)']]

更新(在您自己的回答之后):

我真的认为我上面达到的目的足以根据你的需要进一步调整它,而且我现在没有时间自己做这件事;所以,在这里(在修改我的df定义以去掉括号之后,它只是一个列表理解的问题):

df = rdd.map(lambda x: (x[0],x[1],x[2], ' '.join(str(e) for e in x[3:]))).toDF(["col1","col2","col3","col4"])

# temp list:
ff = df.groupBy("col1","col2","col3").agg(collect_list("col4")).toPandas().values.tolist()[0]
ff
# [u'PNR1', u'TKT1', u'TEST', [u'a2 a3', u'a5 a6', u'a8 a9']]

# final list of lists:
ll = ff[:-1] + [[x.split(' ') for x in ff[-1]]]
ll

给出了您最初请求的结果:

[u'PNR1', u'TKT1', u'TEST', [[u'a2', u'a3'], [u'a5', u'a6'], [u'a8', u'a9']]]  # requested output

与您自己的答案中提供的方法相比,此方法具有某些优势

  • 它避免了Pyspark UDF,known to be slow
  • 所有处理都在最终(希望小得多)聚合数据中完成,而不是添加和删除列,并在初始(可能更大)的数据中执行地图函数和UDF

答案 1 :(得分:1)

我终于找到了一个解决方案,这不是最好的方法,但我可以继续工作......

[(u'PNR1', u'TKT1', u'TEST', [[u'a2', u'a3'], [u'a5', u'a6'], [u'a8', u'a9']])]

它给出了:

{{1}}

希望这个解决方案可以帮助其他人。

感谢您的所有答案。

答案 2 :(得分:0)

由于您无法更新到2.x,因此您唯一的选择是RDD API。用以下内容替换当前代码:

rdd.map(lambda x: ((x[0], x[1], x[2]), list(x[3:]))).groupByKey().toDF()