在我的项目中,我想实现ADD(+
)功能,但我的参数可能是LongType
,DoubleType
,IntType
。我使用sqlContext.udf.register("add",XXX)
,但我不知道如何编写XXX
,这是为了制作泛型函数。
答案 0 :(得分:5)
您可以通过创建UDF
StructType
来创建通用struct($"col1", $"col2")
,UDF
保存您的值并让UDF
解决此问题。它会作为Row
对象传递到您的val multiAdd = udf[Double,Row](r => {
var n = 0.0
r.toSeq.foreach(n1 => n = n + (n1 match {
case l: Long => l.toDouble
case i: Int => i.toDouble
case d: Double => d
case f: Float => f.toDouble
}))
n
})
val df = Seq((1.0,2),(3.0,4)).toDF("c1","c2")
df.withColumn("add", multiAdd(struct($"c1", $"c2"))).show
+---+---+---+
| c1| c2|add|
+---+---+---+
|1.0| 2|3.0|
|3.0| 4|7.0|
+---+---+---+
,因此您可以执行以下操作:
UDF
你甚至可以做一些有趣的事情,例如将可变数量的列作为输入。实际上,我们上面定义的val df = Seq((1, 2L, 3.0f,4.0),(5, 6L, 7.0f,8.0)).toDF("int","long","float","double")
df.printSchema
root
|-- int: integer (nullable = false)
|-- long: long (nullable = false)
|-- float: float (nullable = false)
|-- double: double (nullable = false)
df.withColumn("add", multiAdd(struct($"int", $"long", $"float", $"double"))).show
+---+----+-----+------+----+
|int|long|float|double| add|
+---+----+-----+------+----+
| 1| 2| 3.0| 4.0|10.0|
| 5| 6| 7.0| 8.0|26.0|
+---+----+-----+------+----+
已经这样做了:
df.withColumn("add", multiAdd(struct(lit(100), $"int", $"long"))).show
+---+----+-----+------+-----+
|int|long|float|double| add|
+---+----+-----+------+-----+
| 1| 2| 3.0| 4.0|103.0|
| 5| 6| 7.0| 8.0|111.0|
+---+----+-----+------+-----+
您甚至可以在混音中添加一个硬编码的数字:
UDF
如果要在SQL语法中使用sqlContext.udf.register("multiAdd", (r: Row) => {
var n = 0.0
r.toSeq.foreach(n1 => n = n + (n1 match {
case l: Long => l.toDouble
case i: Int => i.toDouble
case d: Double => d
case f: Float => f.toDouble
}))
n
})
df.registerTempTable("df")
// Note that 'int' and 'long' are column names
sqlContext.sql("SELECT *, multiAdd(struct(int, long)) as add from df").show
+---+----+-----+------+----+
|int|long|float|double| add|
+---+----+-----+------+----+
| 1| 2| 3.0| 4.0| 3.0|
| 5| 6| 7.0| 8.0|11.0|
+---+----+-----+------+----+
,可以执行以下操作:
sqlContext.sql("SELECT *, multiAdd(struct(*)) as add from df").show
+---+----+-----+------+----+
|int|long|float|double| add|
+---+----+-----+------+----+
| 1| 2| 3.0| 4.0|10.0|
| 5| 6| 7.0| 8.0|26.0|
+---+----+-----+------+----+
这也有效:
SELECT
答案 1 :(得分:2)
我认为您无法注册通用UDF。
如果我们查看register
方法的signature
(实际上,它只是22个register
重载中的一个,用于具有一个参数的UDF,其他的是等效的):
def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction
我们可以看到它使用A1: TypeTag
类型进行了参数化 - TypeTag意味着在注册时,我们必须拥有UDF实际类型的证据 #39;的论点。所以 - 在没有明确输入的情况下传递泛型函数func
无法编译。
对于您的情况,您可以利用Spark自动投射数字类型的能力 - 仅为Double
编写UDF,您也可以将其应用于{{1} } s(输出结果为Int
):
Double