Apache Spark SQL:如何使用GroupBy和Max过滤数据

时间:2019-06-07 01:53:57

标签: apache-spark-sql

我有一个具有以下结构的给定数据集:

https://i.imgur.com/Kk7I1S1.png

我需要使用SparkSQL解决以下问题:数据框

对于每个邮政编码,查找以前发生事故次数最多的客户。如果是平局,则意味着不止一个客户发生相同数量的事故,只需退还其中任何一个即可。对于这些选定的客户中的每一个,输出以下列:邮政编码,客户ID,先前的事故数量。

2 个答案:

答案 0 :(得分:0)

我认为您错过了提供图像链接中提到的数据的机会。我以您的问题为参考,创建了自己的数据集。您可以使用以下代码段作为参考,也可以将df数据框替换为数据集以添加所需的列,例如id等。

      scala> val df  = spark.read.format("csv").option("header","true").load("/user/nikhil/acc.csv")
        df: org.apache.spark.sql.DataFrame = [postcode: string, customer: string ... 1 more field]

        scala> df.show()
        +--------+--------+---------+
        |postcode|customer|accidents|
        +--------+--------+---------+
        |       1|  Nikhil|        5|
        |       2|     Ram|        4|
        |       1|   Shyam|        3|
        |       3|  pranav|        1|
        |       1|   Suman|        2|
        |       3|    alex|        2|
        |       2|     Raj|        5|
        |       4|   arpit|        3|
        |       1|   darsh|        2|
        |       1|   rahul|        3|
        |       2|   kiran|        4|
        |       3|    baba|        4|
        |       4|    alok|        3|
        |       1|   Nakul|        5|
        +--------+--------+---------+


        scala> df.createOrReplaceTempView("tmptable")

   scala> spark.sql(s"""SELECT postcode,customer, accidents FROM (SELECT postcode,customer, accidents, row_number() over (PARTITION BY postcode ORDER BY accidents desc) as rn  from tmptable) WHERE rn = 1""").show(false)
+--------+--------+---------+                                                   
|postcode|customer|accidents|
+--------+--------+---------+
|3       |baba    |4        |
|1       |Nikhil  |5        |
|4       |arpit   |3        |
|2       |Raj     |5        |
+--------+--------+---------+

答案 1 :(得分:0)

您可以在python中使用以下代码获得结果:

from pyspark.sql import Row, Window
import pyspark.sql.functions as F
from pyspark.sql.window import *

l = [(1, '682308', 25), (1, '682308', 23), (2, '682309', 23), (1, '682309', 27), (2, '682309', 22)]
rdd = sc.parallelize(l)
people = rdd.map(lambda x: Row(c_id=int(x[0]), postcode=x[1], accident=int(x[2])))
schemaPeople = sqlContext.createDataFrame(people)
result = schemaPeople.groupby("postcode", "c_id").agg(F.max("accident").alias("accident"))
new_result = result.withColumn("row_num", F.row_number().over(Window.partitionBy("postcode").orderBy(F.desc("accident")))).filter("row_num==1")
new_result.show()