pyspark collect_set在groupby之外的列

时间:2019-11-07 20:12:16

标签: group-by pyspark set collect

我正在尝试使用collect_set获取groupby的 NOT 部分的categorie_name字符串列表。 我的代码是

from pyspark import SparkContext
from pyspark.sql import HiveContext
from pyspark.sql import functions as F

sc = SparkContext("local")
sqlContext = HiveContext(sc)
df = sqlContext.createDataFrame([
     ("1", "cat1", "Dept1", "product1", 7),
     ("2", "cat2", "Dept1", "product1", 100),
     ("3", "cat2", "Dept1", "product2", 3),
     ("4", "cat1", "Dept2", "product3", 5),
    ], ["id", "category_name", "department_id", "product_id", "value"])

df.show()
df.groupby("department_id", "product_id")\
    .agg({'value': 'sum'}) \
    .show()

#            .agg( F.collect_set("category_name"))\

输出为

+---+-------------+-------------+----------+-----+
| id|category_name|department_id|product_id|value|
+---+-------------+-------------+----------+-----+
|  1|         cat1|        Dept1|  product1|    7|
|  2|         cat2|        Dept1|  product1|  100|
|  3|         cat2|        Dept1|  product2|    3|
|  4|         cat1|        Dept2|  product3|    5|
+---+-------------+-------------+----------+-----+

+-------------+----------+----------+
|department_id|product_id|sum(value)|
+-------------+----------+----------+
|        Dept1|  product2|         3|
|        Dept1|  product1|       107|
|        Dept2|  product3|         5|
+-------------+----------+----------+

我想要这个输出

+-------------+----------+----------+----------------------------+
|department_id|product_id|sum(value)| collect_list(category_name)|
+-------------+----------+----------+----------------------------+
|        Dept1|  product2|         3|  cat2                      |
|        Dept1|  product1|       107|  cat1, cat2                |
|        Dept2|  product3|         5|  cat1                      |
+-------------+----------+----------+----------------------------+

尝试1

df.groupby("department_id", "product_id")\
    .agg({'value': 'sum'}) \
    .agg(F.collect_set("category_name")) \
    .show()

我收到此错误:

  

pyspark.sql.utils.AnalysisException:“无法解析'category_name'”   给定的输入列:[department_id,product_id,   sum(value)] ;; \ n'聚合[collect_set('category_name,0,0)AS   collect_set(类别名称)#35] \ n +-总计[department_id#2,   product_id#3],[department_id#2,product_id#3,sum(value#4L)AS   sum(value)#24L] \ n +-LogicalRDD [id#0,category_name#1,   department_id#2,product_id#3,value#4L] \ n“

尝试2 ,我将category_name列为groupby的一部分

df.groupby("category_name", "department_id", "product_id")\
    .agg({'value': 'sum'}) \
    .agg(F.collect_set("category_name")) \
    .show()

可以,但是输出不正确

+--------------------------+
|collect_set(category_name)|
+--------------------------+
|              [cat1, cat2]|
+--------------------------+

1 个答案:

答案 0 :(得分:2)

您可以specify multiple aggregations within one agg()。适用于您的情况的正确语法为:

df.groupby("department_id", "product_id")\
    .agg(F.sum('value'), F.collect_set("category_name"))\
    .show()
#+-------------+----------+----------+--------------------------+
#|department_id|product_id|sum(value)|collect_set(category_name)|
#+-------------+----------+----------+--------------------------+
#|        Dept1|  product2|         3|                    [cat2]|
#|        Dept1|  product1|       107|              [cat1, cat2]|
#|        Dept2|  product3|         5|                    [cat1]|
#+-------------+----------+----------+--------------------------+

您的方法无效,因为第一个.agg()适用于pyspark.sql.group.GroupedData并返回一个新的DataFrame。对agg的后续调用实际上是pyspark.sql.DataFrame.agg

  

df.groupBy.agg()的缩写

因此,实际上,对agg的第二次调用再次进行了分组,这不是您想要的。