我正在尝试使用UDF根据我定义的自定义顺序对结构数组进行排序。
以下是我想要获得的结果类型的示例:
input_tbl
+-------+-------+------+
| id1 | id2 | num |
+-------+-------+------+
| 1 | 2 | 1 |
| 1 | 3 | -3 |
| 1 | 4 | 2 |
+-------+-------+------+
output_tbl
+-------+-------+------+
| id1 | id2 | num |
+-------+-------+------+
| 1 | 3 | -3 |
+-------+-------+------+
案例类和UDF的一些示例代码如下所示。
case class Score(id: String, num: Int) extends Ordered[Score] {
def compare(that: Score): Int = {
abs(this.num-that.num)
}
}
val toScoreType : UserDefinedFunction = udf((id: String, num: Int) => {
Score(id, num)
})
val sortScoreList: UserDefinedFunction = udf((score_list: Array[Score]) => {
score_list.sorted
})
我正在按如下方式调用sortScore UDF:
val temp = input_tbl
.select('id1, toScoreType('id2, 'num).as("score"))
.groupBy('id1)
.agg((collect_set('score)).as("score_list"))
temp.select('id1, sortScoreList('score_list).as("result"))
但是,我收到“ java.lang.ClassCastException:scala.collection.mutable.WrappedArray $ ofRef”错误。
有人对造成此问题的原因有任何想法吗?
答案 0 :(得分:2)
Spark无法映射到案例类的记录(结构)作为UDF的输入。实际上,您的函数toScoreType
不会转换为案例类(请检查数据模式!),在内部又只是一个结构(即Row
)。
您应该重写代码以使用单个UDF:
val sortScoreList: UserDefinedFunction = udf((score_list: Seq[Row]) => {
score_list.map{case Row(id:String,num:Int) => Score(id,num)}.sorted
})
val temp = input_tbl
.groupBy('id1)
.agg((collect_set(struct('id2,'num))).as("score_list"))
temp.select('id1, sortScoreList('score_list).as("result")).show()
但这不会得到期望的结果:
+---+--------------------+
|id1| result|
+---+--------------------+
| 1|[[2, 1], [3, -3],...|
+---+--------------------+
如果只需要一条记录,则您的UDF应该返回1个案例类,例如:
val sortScoreList: UserDefinedFunction = udf((score_list: Seq[Row]) => {
score_list.map{case Row(id:String,num:Int) => Score(id,num)}.sorted.head
})
然后将您的结构转换为列:
temp.select('id1, sortScoreList('score_list).as("result"))
.select($"id1",$"result.*")
.show()
编辑:
要获得您想要的结果,我会这样做:
case class Score(id: String, num: Int)
val sortScoreList: UserDefinedFunction = udf((score_list: Seq[Row]) => {
score_list.map{case Row(id:String,num:Int) => Score(id,num)}.minBy(_.num)
})
temp.select('id1, sortScoreList('score_list).as("result"))
.select($"id1",$"result.*")
.show()
+---+---+---+
|id1| id|num|
+---+---+---+
| 1| 3| -3|
+---+---+---+