pyspark - Create top-3 groups and aggregate the other groups/rows

Time: 2017-07-24 17:19:55

Tags: python apache-spark dataframe aggregation

I want to create a new DataFrame where the type column contains only the top X types per name, based on the highest cnt. All other types are rolled up into a single row (other) whose cnt is the sum of the remaining types for that name. The DataFrame is built like this:

data = spark.createDataFrame([
    ("name1", "type1", 2), ("name1", "type2", 1), ("name1", "type3", 4), ("name1", "type3", 5),
    ("name2", "type1", 6), ("name2", "type1", 7), ("name2", "type2", 8)
], ["name", "type", "cnt"])
data.printSchema()

What I have:

|name  |type |cnt|
|------|-----|---|
|name1 |typeA|  6|
|name1 |typeX|  5|
|name1 |typeW|  3|
|name1 |typeZ|  1|
|name2 |typeA|  7|
|name2 |typeB|  2|
| .... | ... |   |

The resulting DataFrame (for top 2) would be: each name keeps its top-2 values plus an 'other' row (3 groups per name):

|name  |type |cnt|
|------|-----|---|
|name1 |typeA|  6|
|name1 |typeX|  5|
|name1 |other|  4|
|name2 |typeA|  7|
|name2 |typeB|  2|
|name2 |other|  0|
| .... | ... |   |

I don't know how to skip the first X rows of each group and then aggregate the remaining rows.

1 Answer:

Answer 0 (score: 2)

I tried using a window function to rank the rows within each name by cnt, then filtered the top-2 ranks per name, aggregated the rest into 'other', and finally unioned the two results.

>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.getOrCreate()
>>> data = spark.createDataFrame([
  ("name1", "type1", 2), ("name1", "type2", 1), ("name1", "type3", 4), ("name1", "type3", 5), \
  ("name2", "type1", 6), ("name2", "type1", 7), ("name2", "type2", 8) \
],["name", "type", "cnt"])
>>> data.show()
+-----+-----+---+
| name| type|cnt|
+-----+-----+---+
|name1|type1|  2|
|name1|type2|  1|
|name1|type3|  4|
|name1|type3|  5|
|name2|type1|  6|
|name2|type1|  7|
|name2|type2|  8|
+-----+-----+---+

>>> from pyspark.sql.window import Window
>>> from pyspark.sql.functions import rank, col,lit
>>> window = Window.partitionBy(data['name']).orderBy(data['cnt'].desc())
>>> data1 = data.select('*', rank().over(window).alias('rank'))
>>> data1.show()
+-----+-----+---+----+
| name| type|cnt|rank|
+-----+-----+---+----+
|name1|type3|  5|   1|
|name1|type3|  4|   2|
|name1|type1|  2|   3|
|name1|type2|  1|   4|
|name2|type2|  8|   1|
|name2|type1|  7|   2|
|name2|type1|  6|   3|
+-----+-----+---+----+
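
Note that rank() assigns tied cnt values the same rank, so a name could keep more than two rows when counts tie. If exactly two rows per name are required, row_number() is a drop-in alternative (a small sketch, not part of the original answer; data1_strict is a hypothetical name):

from pyspark.sql.functions import row_number

# row_number() breaks ties arbitrarily, guaranteeing one row per rank value
data1_strict = data.select('*', row_number().over(window).alias('rank'))
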
>>> data2 = data1.filter(data1['rank'] > 2).groupby('name').sum('cnt').select('name',lit('other').alias('type'),col('sum(cnt)').alias('cnt'))
>>> data2.show()
+-----+-----+---+
| name| type|cnt|
+-----+-----+---+
|name1|other|  3|
|name2|other|  6|
+-----+-----+---+
>>> data1.filter(data1['rank'] <=2).select('name','type','cnt').union(data2).show()
+-----+-----+---+
| name| type|cnt|
+-----+-----+---+
|name1|type3|  5|
|name1|type3|  4|
|name2|type2|  8|
|name2|type1|  7|
|name1|other|  3|
|name2|other|  6|
+-----+-----+---+
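
For reuse, the same steps can be wrapped into a small helper. This is a sketch under the same assumptions as the answer above; top_n_with_other is a hypothetical name, not from the original post:

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col, lit

def top_n_with_other(df, n=2):
    # Rank rows within each name by cnt, descending
    w = Window.partitionBy(df['name']).orderBy(df['cnt'].desc())
    ranked = df.select('*', rank().over(w).alias('rank'))
    # Keep the top-n rows per name
    top = ranked.filter(col('rank') <= n).select('name', 'type', 'cnt')
    # Sum everything below rank n into a single 'other' row per name
    other = (ranked.filter(col('rank') > n)
                   .groupBy('name').sum('cnt')
                   .select('name', lit('other').alias('type'), col('sum(cnt)').alias('cnt')))
    return top.union(other)

top_n_with_other(data, 2).show()

As in the answer, names whose rows all fall inside the top n produce no 'other' row; the zero-valued 'other' rows shown in the question's expected output would need an extra step.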