I need to aggregate rows in a DataFrame by collecting the values in a certain column of each group into a set. pyspark.sql.functions.collect_set does exactly what I need.
However, I need to do this for two columns in turn, because I need to group the input by one column, divide each group into subgroups by another column, and do some aggregation on each subgroup. I don't see how to get collect_set to create a set for each group.
Example:
df = spark.createDataFrame([('a', 'x', 11, 22), ('a', 'y', 33, 44), ('b', 'x', 55, 66), ('b', 'y', 77, 88), ('a', 'x', 12, 23), ('a', 'y', 34, 45), ('b', 'x', 56, 67), ('b', 'y', 78, 89)], ('col1', 'col2', 'col3', 'col4'))
df.show()
+----+----+----+----+
|col1|col2|col3|col4|
+----+----+----+----+
| a| x| 11| 22|
| a| y| 33| 44|
| b| x| 55| 66|
| b| y| 77| 88|
| a| x| 12| 23|
| a| y| 34| 45|
| b| x| 56| 67|
| b| y| 78| 89|
+----+----+----+----+
g1 = df.groupBy('col1', 'col2').agg(collect_set('col3'), collect_set('col4'))
g1.show()
+----+----+-----------------+-----------------+
|col1|col2|collect_set(col3)|collect_set(col4)|
+----+----+-----------------+-----------------+
| a| x| [12, 11]| [22, 23]|
| b| y| [78, 77]| [88, 89]|
| a| y| [33, 34]| [45, 44]|
| b| x| [56, 55]| [66, 67]|
+----+----+-----------------+-----------------+
g2 = g1.groupBy('col1').agg(collect_set('collect_set(col3)'), collect_set('collect_set(col4)'), count('col2'))
g2.show(truncate=False)
+----+--------------------------------------------+--------------------------------------------+-----------+
|col1|collect_set(collect_set(col3)) |collect_set(collect_set(col4)) |count(col2)|
+----+--------------------------------------------+--------------------------------------------+-----------+
|b |[WrappedArray(56, 55), WrappedArray(78, 77)]|[WrappedArray(66, 67), WrappedArray(88, 89)]|2 |
|a |[WrappedArray(33, 34), WrappedArray(12, 11)]|[WrappedArray(22, 23), WrappedArray(45, 44)]|2 |
+----+--------------------------------------------+--------------------------------------------+-----------+
I'd like the result to look more like
+----+----------------+----------------+-----------+
|col1| ...col3... | ...col4... |count(col2)|
+----+----------------+----------------+-----------+
|b |[56, 55, 78, 77]|[66, 67, 88, 89]|2 |
|a |[33, 34, 12, 11]|[22, 23, 45, 44]|2 |
+----+----------------+----------------+-----------+
but I don't see an aggregate function to take the union of two or more sets, or a pyspark operation to flatten the "array of arrays" structure that shows up in g2.
Does pyspark provide a simple way to accomplish this? Or is there a totally different approach I should be taking?
Answer 0: (score: 2)
In PySpark 2.4.5, you can use the built-in flatten function.
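For intuition, flatten removes exactly one level of nesting from an array of arrays. The plain-Python sketch below (not Spark API; in Spark you would call pyspark.sql.functions.flatten on the array column itself) shows the same operation:

```python
# Plain-Python illustration of what Spark's flatten does to an array of arrays:
# remove exactly one level of nesting, preserving element order.
from itertools import chain

def flatten_once(arrays):
    """Chain the inner lists together into a single flat list."""
    return list(chain.from_iterable(arrays))

print(flatten_once([[56, 55], [78, 77]]))  # [56, 55, 78, 77]
```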
Answer 1: (score: 1)
You can flatten the columns with a UDF afterwards:
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType

flatten = udf(lambda l: [x for i in l for x in i], ArrayType(IntegerType()))
I took the liberty of renaming the columns of g2 as col3 and col4 to save typing. This gives:
g3 = g2.withColumn('col3flat', flatten('col3'))
>>> g3.show()
+----+--------------------+--------------------+-----+----------------+
|col1| col3| col4|count| col3flat|
+----+--------------------+--------------------+-----+----------------+
| b|[[78, 77], [56, 55]]|[[66, 67], [88, 89]]| 2|[78, 77, 56, 55]|
| a|[[12, 11], [33, 34]]|[[22, 23], [45, 44]]| 2|[12, 11, 33, 34]|
+----+--------------------+--------------------+-----+----------------+
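One caveat with this approach: flattening the collected sets can leave duplicate values when the same value appears in more than one subgroup. If that matters, Spark 2.4+ also provides array_distinct to dedupe the flattened array. As a plain-Python sketch (a hypothetical helper for illustration, not Spark code), the combined flatten-plus-dedupe looks roughly like this:

```python
# Hypothetical helper mirroring flatten followed by array_distinct:
# flatten one level, then drop duplicates, keeping first-seen order.
def flatten_distinct(arrays):
    seen = set()
    out = []
    for arr in arrays:
        for x in arr:
            if x not in seen:
                seen.add(x)
                out.append(x)
    return out

print(flatten_distinct([[12, 11], [33, 34, 11]]))  # [12, 11, 33, 34]
```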
Answer 2: (score: 0)
You can accomplish the same thing with:
from pyspark.sql.functions import collect_set, countDistinct
(
    df.
    groupby('col1').
    agg(
        collect_set('col3').alias('col3_vals'),
        collect_set('col4').alias('col4_vals'),
        countDistinct('col2').alias('num_grps')
    ).
    show(truncate=False)
)
+----+----------------+----------------+--------+
|col1|col3_vals |col4_vals |num_grps|
+----+----------------+----------------+--------+
|b |[78, 56, 55, 77]|[66, 88, 67, 89]|2 |
|a |[33, 12, 34, 11]|[45, 22, 44, 23]|2 |
+----+----------------+----------------+--------+
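For reference, the single grouped pass above is equivalent to this plain-Python aggregation over the question's example rows (a sketch to show the semantics, not Spark code):

```python
# Plain-Python equivalent of: groupby('col1') with collect_set on col3/col4
# and countDistinct on col2, run over the question's example rows.
from collections import defaultdict

rows = [('a', 'x', 11, 22), ('a', 'y', 33, 44), ('b', 'x', 55, 66),
        ('b', 'y', 77, 88), ('a', 'x', 12, 23), ('a', 'y', 34, 45),
        ('b', 'x', 56, 67), ('b', 'y', 78, 89)]

col3_vals = defaultdict(set)  # collect_set('col3') per col1 group
col4_vals = defaultdict(set)  # collect_set('col4') per col1 group
col2_vals = defaultdict(set)  # distinct col2 values per col1 group
for c1, c2, c3, c4 in rows:
    col3_vals[c1].add(c3)
    col4_vals[c1].add(c4)
    col2_vals[c1].add(c2)

print(sorted(col3_vals['a']))  # [11, 12, 33, 34]
print(len(col2_vals['b']))     # 2  (the countDistinct('col2') value)
```

Note that this skips the intermediate per-(col1, col2) grouping entirely, which is why no flattening step is needed.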