Spark DataFrame UDF: partitioning a column

Asked: 2017-09-10 16:28:49

Tags: apache-spark dataframe user-defined-functions

I want to transform a column: the new column should contain only one partition (one group) of the original column. I defined the following UDF:

def extract(index: Integer) = udf((v: Seq[Double]) => v.grouped(16).toSeq(index))

which I later use in a loop:
myDF = myDF.withColumn("measurement_"+i,extract(i)($"vector"))

The original vector column was created with:

var vectors: Seq[Seq[Double]] = myVectors
val myDF = vectors.toDF("vector")  // requires import spark.implicits._

But in the end I get the following error:

Failed to execute user defined function(anonfun$user$sparkapp$MyClass$$extract$2$1: (array<double>) => array<double>)

Have I defined the UDF incorrectly?

1 Answer:

Answer 0 (score: 1)

I can reproduce the error when I try to extract an element that does not exist, i.e. when the index is beyond the length of the grouped sequence:

val myDF = Seq(Seq(1.0, 2.0 ,3, 4.0), Seq(4.0,3,2,1)).toDF("vector")
myDF: org.apache.spark.sql.DataFrame = [vector: array<double>]

def extract (index : Integer) = udf((v: Seq[Double]) => v.grouped(2).toSeq(index))
// extract: (index: Integer)org.apache.spark.sql.expressions.UserDefinedFunction

val i = 2

myDF.withColumn("measurement_"+i,extract(i)($"vector")).show

which gives this error:

org.apache.spark.SparkException: Failed to execute user defined function($anonfun$extract$1: (array<double>) => array<double>)

You are most likely running into the same problem when executing toSeq(index). Try toSeq.lift(index) instead, which returns None when the index is out of range:

def extract (index : Integer) = udf((v: Seq[Double]) => v.grouped(2).toSeq.lift(index))
extract: (index: Integer)org.apache.spark.sql.expressions.UserDefinedFunction
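The difference is plain Scala collection behavior, independent of Spark: calling apply on a Seq throws IndexOutOfBoundsException for an out-of-range index, while lift returns an Option. A minimal sketch outside Spark:

```scala
// Plain Scala, no Spark needed: grouped(2) splits the vector into
// chunks of two, then apply vs. lift differ on out-of-range indices.
val v = Seq(1.0, 2.0, 3.0, 4.0)
val groups = v.grouped(2).toSeq   // Seq(Seq(1.0, 2.0), Seq(3.0, 4.0))

groups.lift(1)                    // Some(Seq(3.0, 4.0))
groups.lift(2)                    // None -- no exception thrown
// groups(2)                      // would throw IndexOutOfBoundsException
```

Inside the UDF, the exception thrown by groups(index) is what Spark wraps in the "Failed to execute user defined function" message.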

Valid index

val i = 1
myDF.withColumn("measurement_"+i,extract(i)($"vector")).show
+--------------------+-------------+
|              vector|measurement_1|
+--------------------+-------------+
|[1.0, 2.0, 3.0, 4.0]|   [3.0, 4.0]|
|[4.0, 3.0, 2.0, 1.0]|   [2.0, 1.0]|
+--------------------+-------------+

Index out of range

val i = 2
myDF.withColumn("measurement_"+i,extract(i)($"vector")).show
+--------------------+-------------+
|              vector|measurement_2|
+--------------------+-------------+
|[1.0, 2.0, 3.0, 4.0]|         null|
|[4.0, 3.0, 2.0, 1.0]|         null|
+--------------------+-------------+
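If nulls in the result column are inconvenient downstream, one possible variant (extractOrEmpty is a hypothetical name, not part of the answer) falls back to an empty sequence instead. The inner function is shown in plain Scala; it would be wrapped in Spark's udf(...) exactly as above:

```scala
// Hypothetical variant: return an empty group instead of null when the
// index is out of range, by replacing None with Seq.empty via getOrElse.
val extractOrEmpty: Int => Seq[Double] => Seq[Double] =
  index => v => v.grouped(2).toSeq.lift(index).getOrElse(Seq.empty[Double])

extractOrEmpty(1)(Seq(1.0, 2.0, 3.0, 4.0))  // Seq(3.0, 4.0)
extractOrEmpty(2)(Seq(1.0, 2.0, 3.0, 4.0))  // empty Seq, not null
```

Whether null or an empty array is the better sentinel depends on how the measurement columns are consumed later.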