I wrote a UDAF in Spark that computes a range representation of integers.
My intermediate results are ArrayBuffers, and the final result is also an ArrayBuffer. But when I run the code, I get this error -
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, localhost): java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to scala.collection.mutable.ArrayBuffer
at $iwC$$iwC$Concat.update(<console>:33)
at org.apache.spark.sql.execution.aggregate.ScalaUDAF.update(udaf.scala:445)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$11.apply(AggregationIterator.scala:178)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$11.apply(AggregationIterator.scala:171)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.processCurrentSortedGroup(SortBasedAggregationIterator.scala:100)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:139)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:30)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.insertAll(BypassMergeSortShuffleWriter.java:119)
at org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:73)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:73)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:41)
at org.apache.spark.scheduler.Task.run(Task.scala:88)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
at java.lang.Thread.run(Thread.java:745)
Here is my aggregate function -
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import scala.collection.mutable.ArrayBuffer
class Concat extends UserDefinedAggregateFunction {
  def inputSchema: StructType =
    StructType(StructField("value", LongType) :: Nil)

  def bufferSchema: StructType = StructType(
    StructField("concatenation", ArrayType(LongType, false)) :: Nil
  )

  def dataType: DataType = ArrayType(LongType, false)

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, new ArrayBuffer[Long]())
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // ranges are kept as flat (start, offset) pairs in the buffer
    val l = buffer.getSeq(0).asInstanceOf[ArrayBuffer[Long]]
    val v = input.getAs[Long](0)
    val n = l.size
    if (n >= 2) {
      val x1 = l(n - 2) // start of the last range
      val x2 = l(n - 1) // offset of the last range
      if (x1 - 1 == v)
        l(n - 2) = v
      else if (x1 + x2 + 1 == v)
        l(n - 1) = x2 + 1
      else
        l += v
      l += 0L
    }
    else {
      l += v
      l += 0L
    }
    buffer.update(0, l)
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val a = buffer1.getSeq(0).asInstanceOf[ArrayBuffer[Long]]
    val b = buffer2.getSeq(0).asInstanceOf[ArrayBuffer[Long]]
    a ++ b
  }

  def evaluate(buffer: Row): Any = {
    buffer(0)
  }
}
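For reference, this is roughly how I call it from spark-shell (the column name "value" and the sample data here are just for illustration, not my real input):

import sqlContext.implicits._ // sqlContext is the one spark-shell provides

val df = sc.parallelize(Seq(1L, 2L, 3L, 7L, 8L)).toDF("value")
val concat = new Concat
df.groupBy().agg(concat(df("value"))).show()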
I also looked at udaf.scala, but I can't figure out how to make this work, and I'm not very proficient in Scala. How can I get it working?
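From the trace, my guess is that Spark hands the buffer column back as a scala.collection.mutable.WrappedArray rather than the ArrayBuffer I stored, so the asInstanceOf casts in update and merge fail. Would copying the deserialized Seq into a fresh ArrayBuffer on each call, as in this untested sketch, be the right direction?

def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  // getSeq returns whatever Spark deserialized (a WrappedArray here),
  // so build a new ArrayBuffer from it instead of casting
  val l = ArrayBuffer(buffer.getSeq[Long](0): _*)
  val v = input.getAs[Long](0)
  // ... same range logic as in my update above ...
  buffer.update(0, l)
}

def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
  val a = ArrayBuffer(buffer1.getSeq[Long](0): _*)
  a ++= buffer2.getSeq[Long](0)
  buffer1.update(0, a) // write the combined result back (I noticed my merge above never does)
}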