How do I get Pyspark to aggregate sets at two levels?

Posted: 2018-03-22 23:46:28

Tags: apache-spark pyspark

I need to aggregate rows in a DataFrame by collecting the values in a certain column in each group into a set. pyspark.sql.functions.collect_set does exactly what I need.

However, I need to do this for two columns in turn, because I need to group the input by one column, divide each group into subgroups by another column, and do some aggregation on each subgroup. I don't see how to get collect_set to produce one combined set for each top-level group instead of a set of sets.

Example:

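# assumes an existing SparkSession named spark (e.g. the pyspark shell)
from pyspark.sql.functions import collect_set, count

df = spark.createDataFrame(
    [('a', 'x', 11, 22), ('a', 'y', 33, 44), ('b', 'x', 55, 66), ('b', 'y', 77, 88),
     ('a', 'x', 12, 23), ('a', 'y', 34, 45), ('b', 'x', 56, 67), ('b', 'y', 78, 89)],
    ('col1', 'col2', 'col3', 'col4'))
df.show()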
+----+----+----+----+
|col1|col2|col3|col4|
+----+----+----+----+
|   a|   x|  11|  22|
|   a|   y|  33|  44|
|   b|   x|  55|  66|
|   b|   y|  77|  88|
|   a|   x|  12|  23|
|   a|   y|  34|  45|
|   b|   x|  56|  67|
|   b|   y|  78|  89|
+----+----+----+----+

g1 = df.groupBy('col1', 'col2').agg(collect_set('col3'), collect_set('col4'))
g1.show()
+----+----+-----------------+-----------------+
|col1|col2|collect_set(col3)|collect_set(col4)|
+----+----+-----------------+-----------------+
|   a|   x|         [12, 11]|         [22, 23]|
|   b|   y|         [78, 77]|         [88, 89]|
|   a|   y|         [33, 34]|         [45, 44]|
|   b|   x|         [56, 55]|         [66, 67]|
+----+----+-----------------+-----------------+

g2 = g1.groupBy('col1').agg(collect_set('collect_set(col3)'), collect_set('collect_set(col4)'), count('col2'))
g2.show(truncate=False)
+----+--------------------------------------------+--------------------------------------------+-----------+
|col1|collect_set(collect_set(col3))              |collect_set(collect_set(col4))              |count(col2)|
+----+--------------------------------------------+--------------------------------------------+-----------+
|b   |[WrappedArray(56, 55), WrappedArray(78, 77)]|[WrappedArray(66, 67), WrappedArray(88, 89)]|2          |
|a   |[WrappedArray(33, 34), WrappedArray(12, 11)]|[WrappedArray(22, 23), WrappedArray(45, 44)]|2          |
+----+--------------------------------------------+--------------------------------------------+-----------+

I'd like the result to look more like this:

+----+----------------+----------------+-----------+
|col1|   ...col3...   |   ...col4...   |count(col2)|
+----+----------------+----------------+-----------+
|b   |[56, 55, 78, 77]|[66, 67, 88, 89]|2          |
|a   |[33, 34, 12, 11]|[22, 23, 45, 44]|2          |
+----+----------------+----------------+-----------+

but I don't see an aggregate function to take the union of two or more sets, or a pyspark operation to flatten the "array of arrays" structure that shows up in g2.

Does pyspark provide a simple way to accomplish this? Or is there a totally different approach I should be taking?

3 answers:

Answer 0 (score: 2)

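In PySpark 2.4 and later (for example 2.4.5), you can use the built-in flatten function to turn the array of arrays in g2 into a single array.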

Answer 1 (score: 1)

You can flatten the columns with a UDF afterwards:

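from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, LongType
flatten = udf(lambda l: [x for i in l for x in i], ArrayType(LongType()))  # concatenate inner arrays; Python ints are inferred as bigint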

I took the liberty of renaming the columns of g2 to col3, col4, and count to save typing. This gives:

g3 = g2.withColumn('col3flat', flatten('col3'))

>>> g3.show()
+----+--------------------+--------------------+-----+----------------+
|col1|                col3|                col4|count|        col3flat|
+----+--------------------+--------------------+-----+----------------+
|   b|[[78, 77], [56, 55]]|[[66, 67], [88, 89]]|    2|[78, 77, 56, 55]|
|   a|[[12, 11], [33, 34]]|[[22, 23], [45, 44]]|    2|[12, 11, 33, 34]|
+----+--------------------+--------------------+-----+----------------+
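The lambda wrapped by the UDF is ordinary Python, so its behaviour can be checked without Spark. In the sketch below the function names `flatten_list` and `flatten_unique` are hypothetical. The first reproduces the lambda and shows that it only concatenates the inner lists, so a value occurring in more than one subgroup is kept twice. The second drops such repeats while preserving first-seen order, which is closer to a true union of sets.

```python
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def flatten_list(nested: Iterable[Iterable[T]]) -> List[T]:
    """Same logic as the UDF's lambda: concatenate the inner lists in order."""
    return [x for inner in nested for x in inner]


def flatten_unique(nested: Iterable[Iterable[T]]) -> List[T]:
    """Concatenate the inner lists, keeping only the first occurrence of each value."""
    seen = set()
    out: List[T] = []
    for inner in nested:
        for x in inner:
            if x not in seen:
                seen.add(x)
                out.append(x)
    return out


# Two subgroup sets that happen to share the value 11.
nested = [[12, 11], [33, 11]]
print(flatten_list(nested))    # [12, 11, 33, 11]
print(flatten_unique(nested))  # [12, 11, 33]
```

Either function could be passed to `udf(...)` in place of the lambda, e.g. `udf(flatten_unique, ArrayType(LongType()))`. With the question's data no value repeats across subgroups, so both give the same elements there.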

Answer 2 (score: 0)

You can accomplish the same thing in a single aggregation with:
from pyspark.sql.functions import collect_set, countDistinct
(
    df.
        groupby('col1').
        agg(
            collect_set('col3').alias('col3_vals'),
            collect_set('col4').alias('col4_vals'), 
            countDistinct('col2').alias('num_grps')
        ).
        show(truncate=False)
)
+----+----------------+----------------+--------+
|col1|col3_vals       |col4_vals       |num_grps|
+----+----------------+----------------+--------+
|b   |[78, 56, 55, 77]|[66, 88, 67, 89]|2       |
|a   |[33, 12, 34, 11]|[45, 22, 44, 23]|2       |
+----+----------------+----------------+--------+
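The single-pass version works because the union of the per-subgroup sets is exactly the set of all values in the top-level group, and the number of subgroups is the number of distinct `col2` values. A plain-Python sketch of that equivalence on the question's data (no Spark involved; `two_level` and `one_level` are hypothetical helper names, and only `col3` is checked since `col4` behaves identically):

```python
from collections import defaultdict

rows = [('a', 'x', 11, 22), ('a', 'y', 33, 44), ('b', 'x', 55, 66), ('b', 'y', 77, 88),
        ('a', 'x', 12, 23), ('a', 'y', 34, 45), ('b', 'x', 56, 67), ('b', 'y', 78, 89)]


def two_level(rows):
    """Set of col3 per (col1, col2), then union of those sets and subgroup count per col1."""
    sub = defaultdict(set)
    for c1, c2, c3, _ in rows:
        sub[(c1, c2)].add(c3)
    merged = defaultdict(set)
    counts = defaultdict(int)
    for (c1, _), vals in sub.items():
        merged[c1] |= vals
        counts[c1] += 1
    return dict(merged), dict(counts)


def one_level(rows):
    """Set of col3 and number of distinct col2 values per col1, in one pass."""
    vals = defaultdict(set)
    groups = defaultdict(set)
    for c1, c2, c3, _ in rows:
        vals[c1].add(c3)
        groups[c1].add(c2)
    return dict(vals), {k: len(v) for k, v in groups.items()}


print(two_level(rows) == one_level(rows))  # True
```

This is why `collect_set` plus `countDistinct` at the `col1` level gives the same sets and counts as the two-step aggregation, as long as no other per-subgroup computation is needed first.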