将UDF火花化为结构的自定义排序数组

时间:2020-01-24 18:35:24

标签: scala sorting apache-spark user-defined-functions

我正在尝试使用UDF根据我定义的自定义顺序对结构数组进行排序。

以下是我想要获得的结果类型的示例:

input_tbl
+-------+-------+------+
| id1   | id2   | num  |
+-------+-------+------+
|   1   |   2   |  1   |
|   1   |   3   | -3   |
|   1   |   4   |  2   |
+-------+-------+------+

output_tbl
+-------+-------+------+
| id1   | id2   | num  |
+-------+-------+------+
|   1   |   3   | -3   |
+-------+-------+------+

案例类和UDF的一些示例代码如下所示。

case class Score(id: String, num: Int) extends Ordered[Score] {

  def compare(that: Score): Int = {
    abs(this.num-that.num)
  }
}

val toScoreType : UserDefinedFunction = udf((id: String, num: Int) => {
    Score(id, num)
})

val sortScoreList: UserDefinedFunction = udf((score_list: Array[Score]) => {
    score_list.sorted
})

我正在按如下方式调用sortScore UDF:

val temp = input_tbl
    .select('id1, toScoreType('id2, 'num).as("score"))
    .groupBy('id1)
    .agg((collect_set('score)).as("score_list"))


temp.select('id1, sortScoreList('score_list).as("result"))

但是,我收到“ java.lang.ClassCastException:scala.collection.mutable.WrappedArray $ ofRef”错误。

有人对造成此问题的原因有任何想法吗?

1 个答案:

答案 0 :(得分:2)

Spark无法映射到案例类的记录(结构)作为UDF的输入。实际上,您的函数toScoreType不会转换为案例类(请检查数据模式!),在内部又只是一个结构(即Row)。

您应该重写代码以使用单个UDF:

val sortScoreList: UserDefinedFunction = udf((score_list: Seq[Row]) => {
  score_list.map{case Row(id:String,num:Int) => Score(id,num)}.sorted
})


val temp = input_tbl
  .groupBy('id1)
  .agg((collect_set(struct('id2,'num))).as("score_list"))

temp.select('id1, sortScoreList('score_list).as("result")).show()

但这不会得到期望的结果:

+---+--------------------+
|id1|              result|
+---+--------------------+
|  1|[[2, 1], [3, -3],...|
+---+--------------------+

如果只需要一条记录,则您的UDF应该返回1个案例类,例如:

val sortScoreList: UserDefinedFunction = udf((score_list: Seq[Row]) => {
  score_list.map{case Row(id:String,num:Int) => Score(id,num)}.sorted.head
})

然后将您的结构转换为列:

temp.select('id1, sortScoreList('score_list).as("result"))
  .select($"id1",$"result.*")
  .show()

编辑:

要获得您想要的结果,我会这样做:

case class Score(id: String, num: Int)

val sortScoreList: UserDefinedFunction = udf((score_list: Seq[Row]) => {
      score_list.map{case Row(id:String,num:Int) => Score(id,num)}.minBy(_.num)
 })


temp.select('id1, sortScoreList('score_list).as("result"))
  .select($"id1",$"result.*")
  .show()

+---+---+---+
|id1| id|num|
+---+---+---+
|  1|  3| -3|
+---+---+---+
相关问题