在PySpark 1.5.0中,如何根据列`x`的值列出列'y`的所有项?

时间:2016-03-20 15:06:10

标签: python apache-spark pyspark

以下问题特定于PySpark 1.5.0版,因为新功能不断添加到PySpark。

如何根据列y的值列出列x的所有项? 例如:

rdd = sc.parallelize([ {'x': "foo", 'y': 1}, 
                  {'x': "foo", 'y': 1}, 
                  {'x': "bar", 'y': 10}, 
                 {'x': "bar", 'y': 2},
                 {'x': 'qux', 'y':999}])
df = sqlCtx.createDataFrame(rdd)
df.show()

+---+---+
|  x|  y|
+---+---+
|foo|  1|
|foo|  1|
|bar| 10|
|bar|  2|
|qux|999|
+---+---+

我希望有类似的东西:

+---+--------+
|  x|  y     |
+---+--------+
|foo| [1, 1] |
|bar| [10, 2]|
|bar| [999]  |
+---+--------+

订单无关紧要。在Pandas,我可以通过以下方式实现这个组织:

pd = df.toPandas()
pd.groupby('x')['y'].apply(list).reset_index()

但是,1.5.0版中的groupBy聚合功能似乎非常有限。知道如何克服这个限制吗?

1 个答案:

答案 0 :(得分:4)

您可以使用collect_list Hive UDAF:

from pyspark.sql.functions import expr
from pyspark import HiveContext

sqlContext = HiveContext(sc)
df = sqlContext.createDataFrame(rdd)

df.groupBy("x").agg(expr("collect_list(y) AS y"))

在1.6或更高版本中,您可以使用collect_list函数:

from pyspark.sql.functions import collect_list

df.groupBy("x").agg(collect_list(y).alias("y"))

并且在2.0或更高版本中,您可以在没有Hive支持的情况下使用它。

这不是一个特别有效的操作,所以你应该适度使用它。

此外,不要使用字典进行模式推断。它自1.2以来已被弃用