How do I apply a customizable Aggregator on a Spark Dataset?

Time: 2019-07-30 17:49:55

Tags: scala apache-spark apache-spark-dataset

I have a Spark Dataset of student records with the following schema.

id | name | subject | score
1  | Tom  | Math    | 99
1  | Tom  | Math    | 88
1  | Tom  | Physics | 77
2  | Amy  | Math    | 66

My goal is to transform this Dataset into another one that, for each student, lists the highest-scoring record per subject:

id | name | subject_score_list
1  | Tom  | [(Math, 99), (Physics, 77)]
2  | Amy  | [(Math, 66)]

I decided to use an Aggregator to do the transformation, after converting the Dataset into key-value pairs of the form ((id, name), (subject, score)).

For the buffer I tried a mutable Map[String, Integer], so that when a subject already exists and the new score is higher, I can update it in place. This is what the Aggregator looks like:

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

type StudentSubjectPair = ((String, String), (String, Integer))
type SubjectMap = collection.mutable.Map[String, Integer]
type SubjectList = List[(String, Integer)]

val StudentSubjectAggregator = new Aggregator[StudentSubjectPair, SubjectMap, SubjectList] {
  def zero: SubjectMap = collection.mutable.Map[String, Integer]()

  def reduce(buf: SubjectMap, input: StudentSubjectPair): SubjectMap = {
    val (subject, score) = input._2
    if (buf.contains(subject))
      buf(subject) = math.max(buf(subject), score)
    else
      buf(subject) = score
    buf
  }

  def merge(b1: SubjectMap, b2: SubjectMap): SubjectMap = {
    for ((subject, score) <- b2) {
      if (b1.contains(subject))
        b1(subject) = math.max(b1(subject), score)
      else
        b1(subject) = score
    }
    b1
  }

  def finish(buf: SubjectMap): SubjectList = buf.toList

  override def bufferEncoder: Encoder[SubjectMap] = ExpressionEncoder[SubjectMap]
  override def outputEncoder: Encoder[SubjectList] = ExpressionEncoder[SubjectList]
}.toColumn.name("subject_score_list")
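Stripped of the Aggregator machinery, what reduce and merge are meant to compute is just a max-merge of subject-to-score maps. A plain-Scala sketch of that logic, with no Spark dependency:

```scala
// Max-merge of two subject-to-score maps: for each subject keep the
// highest score seen in either map. This mirrors the merge step above.
def maxMerge(b1: Map[String, Int], b2: Map[String, Int]): Map[String, Int] =
  b2.foldLeft(b1) { case (acc, (subject, score)) =>
    acc.updated(subject, math.max(acc.getOrElse(subject, score), score))
  }
```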

I chose an Aggregator because it is customizable: if I later want, say, the average score per subject, I can simply change the reduce and merge functions. I am hoping for two answers to this post:

  1. Is an Aggregator a good way to do this job? Is there another, simpler way to get the same output?
  2. What are the correct encoders for collection.mutable.Map[String, Integer] and List[(String, Integer)]? I always get the following error:
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 37.0 failed 1 times, most recent failure: Lost task 0.0 in stage 37.0 (TID 231, localhost, executor driver):
java.lang.ClassCastException: scala.collection.immutable.HashMap$HashTrieMap cannot be cast to scala.collection.mutable.Map
    at $anon$1.merge(<console>:54)
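The stack trace shows what happens between tasks: the buffer is deserialized back as an immutable HashMap, which cannot be cast to collection.mutable.Map. The same fold logic can be written over an immutable Map instead (a plain-Scala sketch of just the three buffer operations; whether this is the right way to satisfy the encoder is exactly what I am asking):

```scala
// Buffer operations over an immutable Map[String, Int]: no mutation,
// so there is no mutable-vs-immutable cast to fail on deserialization.
type Buf = Map[String, Int]

def zero: Buf = Map.empty[String, Int]

// Keep the higher score for a subject, inserting it if absent.
def reduce(buf: Buf, subject: String, score: Int): Buf =
  buf.updated(subject, math.max(buf.getOrElse(subject, score), score))

// Merge two buffers by folding one into the other with reduce.
def merge(b1: Buf, b2: Buf): Buf =
  b2.foldLeft(b1) { case (acc, (subject, score)) => reduce(acc, subject, score) }
```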

Thanks for any input and help!

1 Answer:

Answer 0: (score: 1)

I think you can achieve the desired result using the DataFrame API.

import spark.implicits._  // needed for toDF outside spark-shell

val df = Seq((1, "Tom", "Math",    99),
    (1, "Tom", "Math",    88),
    (1, "Tom", "Physics", 77),
    (2, "Amy", "Math",    66)).toDF("id", "name", "subject", "score")

Group by id, name and subject to get the max score, then group by id and name with collect_list over a map of (subject, score):

import org.apache.spark.sql.functions._

df.groupBy("id", "name", "subject").agg(max("score").as("score"))
  .groupBy("id", "name")
  .agg(collect_list(map($"subject", $"score")).as("subject_score_list"))
  .show()


+---+----+--------------------+
| id|name|  subject_score_list|
+---+----+--------------------+
|  1| Tom|[[Physics -> 77],...|
|  2| Amy|      [[Math -> 66]]|
+---+----+--------------------+
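If you want the output shaped exactly like the question's [(Math, 99), (Physics, 77)] (an array of (subject, score) pairs rather than a list of single-entry maps), collect_list over struct works too. A self-contained sketch, assuming a local SparkSession (the object name and master setting are illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object MaxScorePerSubject {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("max-scores").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "Tom", "Math", 99), (1, "Tom", "Math", 88),
                 (1, "Tom", "Physics", 77), (2, "Amy", "Math", 66))
      .toDF("id", "name", "subject", "score")

    // Highest score per (id, name, subject), then each student's rows
    // collected into an array of (subject, score) structs.
    val result = df.groupBy("id", "name", "subject").agg(max("score").as("score"))
      .groupBy("id", "name")
      .agg(collect_list(struct($"subject", $"score")).as("subject_score_list"))

    result.show(truncate = false)
    spark.stop()
  }
}
```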