I have a Spark Dataset with the following schema and student records:
id | name | subject | score
1 | Tom | Math | 99
1 | Tom | Math | 88
1 | Tom | Physics | 77
2 | Amy | Math | 66
My goal is to transform this Dataset into another one that, for every student, lists the highest-scoring record per subject:
id | name | subject_score_list
1 | Tom | [(Math, 99), (Physics, 77)]
2 | Amy | [(Math, 66)]
After converting the Dataset to key-value pairs of the form ((id, name), (subject, score)), I decided to use an Aggregator to do the transformation. For the buffer I tried a mutable Map[String, Integer], so that when a subject already exists and the new score is higher, the score can be updated. Here is what the Aggregator looks like:
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

type StudentSubjectPair = ((String, String), (String, Integer))
type SubjectMap = collection.mutable.Map[String, Integer]
type SubjectList = List[(String, Integer)]

val StudentSubjectAggregator = new Aggregator[StudentSubjectPair, SubjectMap, SubjectList] {
  def zero: SubjectMap = collection.mutable.Map[String, Integer]()

  // Keep the highest score seen so far for each subject.
  def reduce(buf: SubjectMap, input: StudentSubjectPair): SubjectMap = {
    val (subject, score) = input._2
    if (buf.contains(subject))
      buf(subject) = math.max(buf(subject), score)
    else
      buf(subject) = score
    buf
  }

  // Combine two partial buffers, again keeping the max per subject.
  def merge(b1: SubjectMap, b2: SubjectMap): SubjectMap = {
    for ((subject, score) <- b2) {
      if (b1.contains(subject))
        b1(subject) = math.max(b1(subject), score)
      else
        b1(subject) = score
    }
    b1
  }

  def finish(buf: SubjectMap): SubjectList = buf.toList

  // These encoder choices are what my second question is about.
  override def bufferEncoder: Encoder[SubjectMap] = ExpressionEncoder[SubjectMap]
  override def outputEncoder: Encoder[SubjectList] = ExpressionEncoder[SubjectList]
}.toColumn.name("subject_score_list")
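The reduce semantics being aimed for here (keep the highest score seen per subject) can be exercised in plain Scala without a Spark session; the helper name below is illustrative and not part of the original code:

```scala
// Illustrative stand-alone version of the max-per-subject buffer update.
def reduceMax(buf: collection.mutable.Map[String, Int],
              subject: String, score: Int): collection.mutable.Map[String, Int] = {
  // getOrElse avoids a separate contains check
  buf(subject) = math.max(buf.getOrElse(subject, Int.MinValue), score)
  buf
}

val buf = collection.mutable.Map[String, Int]()
reduceMax(buf, "Math", 99)
reduceMax(buf, "Math", 88)      // lower score: must not overwrite 99
reduceMax(buf, "Physics", 77)
```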
I chose an Aggregator because I found it customizable: if I later want, say, the average score per subject, I can simply change the reduce and merge functions.
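For instance, an average per subject could use a buffer of subject -> (sum, count). This is a plain-Scala sketch of what those changed functions might look like (the names are hypothetical, and the Spark Aggregator wiring is omitted):

```scala
// Hypothetical buffer for averages: subject -> (running sum, count).
type AvgBuf = collection.mutable.Map[String, (Int, Int)]

// reduce: fold one (subject, score) record into the buffer.
def reduceAvg(buf: AvgBuf, subject: String, score: Int): AvgBuf = {
  val (sum, cnt) = buf.getOrElse(subject, (0, 0))
  buf(subject) = (sum + score, cnt + 1)
  buf
}

// merge: combine two partial buffers by adding sums and counts.
def mergeAvg(b1: AvgBuf, b2: AvgBuf): AvgBuf = {
  for ((subject, (sum, cnt)) <- b2) {
    val (s1, c1) = b1.getOrElse(subject, (0, 0))
    b1(subject) = (s1 + sum, c1 + cnt)
  }
  b1
}

// finish: turn (sum, count) into the average.
def finishAvg(buf: AvgBuf): List[(String, Double)] =
  buf.toList.map { case (s, (sum, cnt)) => s -> sum.toDouble / cnt }

val b = collection.mutable.Map[String, (Int, Int)]()
reduceAvg(b, "Math", 99)
reduceAvg(b, "Math", 88)
val avgs = finishAvg(b).toMap
```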
I am hoping for two answers to this post:

1. Is an Aggregator a good way to do this job, and is there another simple way to get the same output?
2. What are the correct encoders for collection.mutable.Map[String, Integer] and List[(String, Integer)]? I always run into the following error:

org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 37.0 failed 1 times, most recent failure: Lost task 0.0 in stage 37.0 (TID 231, localhost, executor driver):
java.lang.ClassCastException: scala.collection.immutable.HashMap$HashTrieMap cannot be cast to scala.collection.mutable.Map
at $anon$1.merge(<console>:54)
Thanks for any input and help!
Answer 0 (score: 1)
I think you can achieve your desired result using the DataFrame API.
val df = Seq((1, "Tom", "Math", 99),
  (1, "Tom", "Math", 88),
  (1, "Tom", "Physics", 77),
  (2, "Amy", "Math", 66)).toDF("id", "name", "subject", "score")
Group by id, name, and subject to get the max score, then group by id and name, using collect_list over a map of subject to score:
import org.apache.spark.sql.functions._  // max, collect_list, map

df.groupBy("id", "name", "subject").agg(max("score").as("score"))
  .groupBy("id", "name")
  .agg(collect_list(map($"subject", $"score")).as("subject_score_list"))
+---+----+--------------------+
| id|name| subject_score_list|
+---+----+--------------------+
| 1| Tom|[[Physics -> 77],...|
| 2| Amy| [[Math -> 66]]|
+---+----+--------------------+
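The expected shape of that result can also be double-checked with plain Scala collections, no SparkSession required; the Record case class below is illustrative:

```scala
// Plain-collections replica of: group by (id, name, subject) taking the max
// score, then regroup by (id, name) collecting (subject, score) pairs.
case class Record(id: Int, name: String, subject: String, score: Int)

val rows = List(
  Record(1, "Tom", "Math", 99),
  Record(1, "Tom", "Math", 88),
  Record(1, "Tom", "Physics", 77),
  Record(2, "Amy", "Math", 66))

val result: Map[(Int, String), List[(String, Int)]] =
  rows.groupBy(r => (r.id, r.name, r.subject))
    .map { case ((id, name, subject), rs) => (id, name, subject, rs.map(_.score).max) }
    .groupBy { case (id, name, _, _) => (id, name) }
    .map { case (key, recs) =>
      // sort by subject name so the output order is deterministic
      key -> recs.map { case (_, _, subject, score) => (subject, score) }.toList.sortBy(_._1)
    }
```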