I have been trying to define a function that operates on a Spark DataFrame column, taking a Scala Set as input and returning an integer. I get the following error:
org.apache.spark.SparkException: Job aborted due to stage failure: Task 20 in stage 25.0 failed 1 times, most recent failure: Lost task 20.0 in stage 25.0 (TID 473, localhost): java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to scala.collection.immutable.Set
Here is a minimal example that captures the essence of the problem:
// generate sample data
case class Dummy(x: Array[Integer])
val df = sqlContext.createDataFrame(Seq(
  Dummy(Array(1, 2, 3)),
  Dummy(Array(10, 20, 30, 40))
))

// define the UDF
import org.apache.spark.sql.functions._
def setSize(A: Set[Integer]): Integer = {
  A.size
}

// For some reason I couldn't get it to work without this valued function
val sizeWrap: (Set[Integer] => Integer) = setSize(_)
val sizeUDF = udf(sizeWrap)

// this produces the error
df.withColumn("colSize", sizeUDF('x)).show
What am I missing here? How can I make this work? I know I could do it by converting to an RDD, but I don't want to switch back and forth between RDDs and DataFrames.
Answer 0 (score: 1)
Use Seq in the UDF signature instead. Spark SQL hands an ArrayType column to a UDF as a WrappedArray, which is a Seq, never as a Set, so accept a Seq and do the conversion yourself:

val sizeUDF = udf((x: Seq[Integer]) => setSize(x.toSet))
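
Applied to the question's DataFrame, this returns the size of each row's array. A minimal usage sketch, reusing the df and setSize defined above (exact show formatting varies across Spark versions; the expected counts come from the sample data):

// the Seq received by the UDF is backed by WrappedArray; toSet avoids the ClassCastException
df.withColumn("colSize", sizeUDF('x)).show
// expected: colSize = 3 for the first row, 4 for the second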