定义数值类型SparkSQL scala的函数

时间:2016-03-10 08:37:49

标签: scala types apache-spark user-defined-functions implicit-conversion

我已经定义了以下函数来注册为UDF SparkSQL:

def array_sum(x: WrappedArray[Long]): Long= {
    x.sum
}

我希望这个函数适用于作为参数接收的任何数字类型。我尝试了以下方法:

import Numeric.Implicits._ 
import scala.reflect.ClassTag

def array_sum(x: WrappedArray[NumericType]) = {
   x.sum
}

但它不起作用。有任何想法吗?谢谢!

1 个答案:

答案 0 :(得分:0)

NumericType是特定于Spark SQL的,并且永远不会暴露给接收标准Scala对象的UDF。所以你很可能想要这样的东西:

def array_sum[T : Numeric : ClassTag](x: Seq[T]) = x.sum
udf[Double, Seq[Double]](array_sum _)

虽然看起来不像这里有很多好处。要以正确的方式构建这样的东西,你应该实现自定义表达式。

使用示例:

val rddDouble: RDD[(Long, Array[Double])] = sc.parallelize(Seq(1L, Array(1.0, 2.0)
val double_array_sum = udf[Double, Seq[Double]](array_sum _)
rddDouble.toDF("k", "v").select(double_array_sum($"v")).show

// +------+
// |UDF(v)|
// +------+
// |   3.0|
// +------+

val rddFloat: RDD[(Long, Array[Float])] = sc.parallelize(Seq(
  (1L, Array(1.0f, 2.0f))
))
val float_array_sum = udf[Float, Seq[Float]](array_sum _)
rddFloat.toDF("k", "v").select(float_array_sum($"v")).show

// +------+
// |UDF(v)|
// +------+
// |   3.0|
// +------+