我有一个按预期工作的功能
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import scala.collection.mutable.WrappedArray
def arrayContainsAny(s: Seq[String]): UserDefinedFunction = udf((xs: WrappedArray[String]) => !xs.toList.intersect(s).isEmpty)
我需要从UserDefinedFunction
定义
我已尝试过以下
// imports again
def _arrayContainsAny(s: Seq[String])(c: WrappedArray[String]): Boolean = !c.toList.intersect(s).isEmpty
def arrayContainsAny: UserDefinedFunction = udf[Boolean, WrappedArray[String], Seq[String]](_arrayContainsAny)
但它甚至没有编译。
问题似乎是我将函数定义为udf[X, Y, Z]
因此它需要函数(Z, Y) => X
而不是(Z)(Y) => X
有谁知道怎么做?
- β
答案 0 :(得分:1)
选项1 :
使用带有两个参数列表的方法,当包装在UDF中时,您应该传递第一个参数并使用_
来获取结果函数:
def _arrayContainsAny(s: Seq[String])(xs: mutable.WrappedArray[String]) = xs.toList.intersect(s).nonEmpty
def arrayContainsAny(s: Seq[String]): UserDefinedFunction = {
udf(_arrayContainsAny(s) _)
}
选项2 :
您可以创建一个采用Seq[String]
并返回函数WrappedArray[String] => Boolean
的方法,然后在创建UDF时调用该方法:
def _arrayContainsAny(s: Seq[String]) =
(xs: mutable.WrappedArray[String]) => xs.toList.intersect(s).nonEmpty
def arrayContainsAny(s: Seq[String]): UserDefinedFunction = {
udf(_arrayContainsAny(s))
}