将具有currying的函数插入Spark UserDefinedFunction

时间:2017-09-12 14:08:15

标签: scala apache-spark

我有一个按预期工作的功能

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import scala.collection.mutable.WrappedArray

def arrayContainsAny(s: Seq[String]): UserDefinedFunction = udf((xs: WrappedArray[String]) => !xs.toList.intersect(s).isEmpty)

我需要从UserDefinedFunction定义

中拆分函数

我已尝试过以下

// imports again
def _arrayContainsAny(s: Seq[String])(c: WrappedArray[String]): Boolean = !c.toList.intersect(s).isEmpty
def arrayContainsAny: UserDefinedFunction = udf[Boolean, WrappedArray[String], Seq[String]](_arrayContainsAny)

但它甚至没有编译。

问题似乎是我将函数定义为udf[X, Y, Z]因此它需要函数(Z, Y) => X而不是(Z)(Y) => X

有谁知道怎么做?

- β

1 个答案:

答案 0 :(得分:1)

选项1

使用带有两个参数列表的方法,当包装在UDF中时,您应该传递第一个参数并使用_来获取结果函数:

def _arrayContainsAny(s: Seq[String])(xs: mutable.WrappedArray[String]) = xs.toList.intersect(s).nonEmpty

def arrayContainsAny(s: Seq[String]): UserDefinedFunction = {
  udf(_arrayContainsAny(s) _)
}

选项2

您可以创建一个采用Seq[String]并返回函数WrappedArray[String] => Boolean的方法,然后在创建UDF时调用该方法:

def _arrayContainsAny(s: Seq[String]) = 
  (xs: mutable.WrappedArray[String]) => xs.toList.intersect(s).nonEmpty

def arrayContainsAny(s: Seq[String]): UserDefinedFunction = {
  udf(_arrayContainsAny(s))
}