对于两列中的每一列,选择第三列的前n个

时间:2019-02-22 12:37:48

标签: python apache-spark pyspark apache-spark-sql

我有一个包含多列的csv,但是我对年龄,年龄,性别和死亡原因这三个感兴趣。看起来如下。

foo

年龄组代表类别,7代表70-80,cause_of_death是原因的数字代表,1是心脏病发作,2是事故等。

我必须找出每个性别和年龄组的前n个死亡原因。我现在尝试的是以下内容。

|age_group|gender|cause_of_death|
+-------------+---+---------------+
|            7|  F|              1|
|            8|  M|              2|
|           10|  F|              3|
|            7|  M|              2|
|            9|  F|              3|

但这给了我所有死亡人数的降序计数。

data.select("age_group","gender","cause_of_death")\
.groupBy("gender","age_group","cause_of_death").count()\
.sort(desc("age_group")).show()

我想要的是,对于每个年龄段和性别,前n个死亡原因。我怎样才能做到这一点?以下是死亡的三大原因。

gender|age_group|cause_of_death| count|
+---+-------------+---------------+------+
|  F|           11|              7|308181|
|  F|           10|              7|231168|
|  M|           10|              7|221172|
|  M|           11|              7|157693|
|  F|           11|           null|149345|
|  M|            9|              7|146186|
|  F|            9|              7|114424|
|  F|           10|           null|114107|
|  M|            8|              7|106339|
|  M|           10|           null|105508|
|  M|           11|           null| 75934|
|  F|            8|              7| 70390|
|  M|            9|           null| 69363|
|  M|            7|              7| 65634|

编辑: 评论中的问题未回答我的问题,我尝试过,但未获得正确的结果。

代码:

    gender|age_group|cause_of_death| count|
    +---+-------------+---------------+------+
    |  F|           11|              7|308181|
                                     1|291242|
                                     4|234231|

    |  F|           10|              7|231168|
                                     3|221232|
                                     2|192323|

    |  M|           10|              7|221172|
                                     2|142323|
                                     9| 12312

结果

window = Window.partitionBy(data['age_group']).orderBy(data['cause_of_death'].desc())
data = data.select("age_group","gender","cause_of_death")
data.select('*', rank().over(window).alias('rank')).filter(col('rank') <= 5).show() 

0 个答案:

没有答案