Using ArrayBuffers in a Spark UDAF

Posted: 2016-04-02 17:38:14

Tags: apache-spark apache-spark-sql spark-dataframe

I wrote a UDAF in Spark that computes a range representation of a set of integers.

My intermediate results are ArrayBuffers, and the final result is also an ArrayBuffer. But when I run the code, I get this error -

org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, localhost): java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to scala.collection.mutable.ArrayBuffer
    at $iwC$$iwC$Concat.update(<console>:33)
    at org.apache.spark.sql.execution.aggregate.ScalaUDAF.update(udaf.scala:445)
    at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$11.apply(AggregationIterator.scala:178)
    at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$11.apply(AggregationIterator.scala:171)
    at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.processCurrentSortedGroup(SortBasedAggregationIterator.scala:100)
    at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:139)
    at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:30)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.insertAll(BypassMergeSortShuffleWriter.java:119)
    at org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:73)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:73)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:41)
    at org.apache.spark.scheduler.Task.run(Task.scala:88)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
    at java.lang.Thread.run(Thread.java:745)

Here is my aggregate function -

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import scala.collection.mutable.ArrayBuffer

class Concat extends UserDefinedAggregateFunction {
  def inputSchema: StructType =
    StructType(StructField("value", LongType) :: Nil)

  def bufferSchema: StructType = StructType(
    StructField("concatenation", ArrayType(LongType, false)) :: Nil
  )

  def dataType: DataType = ArrayType(LongType, false)

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, new ArrayBuffer[Long]())
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val l = buffer.getSeq(0).asInstanceOf[ArrayBuffer[Long]]
    val v = input.getAs[Long](0)

    val n = l.size
    if (n >= 2) {
      val x1 = l(n - 2)
      val x2 = l(n - 1)

      if (x1 - 1 == v)
        l(n - 2) = v
      else if (x1 + x2 + 1 == v)
        l(n - 1) = x2 + 1
      else {
        l += v
        l += 0L
      }
    }
    else {
      l += v
      l += 0L
    }

    buffer.update(0, l)
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val a = buffer1.getSeq(0).asInstanceOf[ArrayBuffer[Long]]
    val b = buffer2.getSeq(0).asInstanceOf[ArrayBuffer[Long]]

    a ++ b
  }

  def evaluate(buffer: Row): Any = {
    buffer(0)
  }
}

I also looked at udaf.scala, but I couldn't figure out how to make this work, and I'm not very proficient in Scala. How can I get it to work?

0 Answers:

No answers yet.
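For anyone hitting the same error, the immediate cause is that Spark hands an ArrayType buffer column back as a scala.collection.mutable.WrappedArray, not as the ArrayBuffer that was stored, so the asInstanceOf[ArrayBuffer[Long]] in update (and merge) cannot succeed. A minimal sketch of update and merge that instead reads the column as a plain Seq[Long], copies it into a fresh ArrayBuffer, and writes the result back with buffer.update - shown here only as an illustration under that assumption, not code from the original post:

  // Sketch only: read the array column as a Seq (Spark returns a WrappedArray),
  // copy it into a local ArrayBuffer, mutate the copy, then write it back.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val l = ArrayBuffer(buffer.getSeq[Long](0): _*)  // copy instead of casting
    val v = input.getAs[Long](0)

    val n = l.size
    if (n >= 2) {
      val x1 = l(n - 2)
      val x2 = l(n - 1)

      if (x1 - 1 == v)
        l(n - 2) = v
      else if (x1 + x2 + 1 == v)
        l(n - 1) = x2 + 1
      else {
        l += v
        l += 0L
      }
    } else {
      l += v
      l += 0L
    }

    buffer.update(0, l)  // store the updated copy back into the buffer
  }

  // merge must also write its result back; `a ++ b` on its own builds a new
  // collection and discards it, leaving buffer1 unchanged.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val a = buffer1.getSeq[Long](0)
    val b = buffer2.getSeq[Long](0)
    buffer1.update(0, a ++ b)
  }

With that change the UDAF is invoked as usual, e.g. val concat = new Concat followed by df.groupBy(col("key")).agg(concat(col("value"))) for a DataFrame with a Long column named value (the DataFrame and column names here are only illustrative).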