UDF scala返回[max,index]

时间:2016-03-16 08:02:44

标签: scala apache-spark user-defined-functions apache-spark-sql

我想对Spark SQL实现以下功能。给定一个数组返回索引的最大值。我试过了:

/*
 * This function finds the maximum value and corresponding index in the array. NULLs are ignored. 
 * Return type is array in format [max, index], and its element type is the same as the input type.
 * Parameters: x Array[Int]
 * Returns: Array as [max, index].
 */
def array_max_index(x: WrappedArray[Int]): WrappedArray[Int] = {
    val arr = collection.mutable.WrappedArray.empty
    arr.:+(x.max).:+(x.indexOf(x.max))
}

效果很好,但仅适用于Integers - 我希望UDF能够用于其他数值(例如Double s)。我尝试了以下内容,但是我无法返回类型为的结构:

def array_max_index[T](item:Traversable[T])(implicit n:Numeric[T]): Traversable[T] = {
    val arr = collection.mutable.WrappedArray.empty
    val max = item.max
    val index = n.toInt(item.toSeq.indexOf(max))
    arr.:+(max).:+(index)
  }

有什么想法吗?

1 个答案:

答案 0 :(得分:2)

返回Array并不是很有用 - 因为索引类型总是Int,最大值类型取决于特定的调用(如果我理解正确,你希望它能很好地工作整数和双打) - 因此阵列无法正确输入。

这是UDF的一种可能实现方式,返回元组

def array_max_index[T](x: Traversable[T])(implicit n: Numeric[T]): (T, Int) = {
  (x.max, x.toSeq.indexOf(x.max))
}

然后,可以调用Double以及Int s:

sqlContext.udf.register("array_max_index", array_max_index(_: Traversable[Double]))

sqlContext.sql(
  """SELECT array_max_index(array(
    |  CAST(5.0 AS DOUBLE),
    |  CAST(7.0 AS DOUBLE),
    |  CAST(3.0 AS DOUBLE)
    |)) as max_and_index""".stripMargin).show

打印哪些:

+-------------+
|max_and_index|
+-------------+
|      [7.0,1]|
+-------------+