用户定义的汇总函数(UDAF)始终处于“返回零”状态

时间:2018-06-19 14:31:30

标签: scala apache-spark collections user-defined-functions

我在两列中有一个带有unix时间戳记的数据帧,我正在编写一个用户定义的聚合函数,该函数将在数据帧的列中提供的指定范围内的索引上返回值为1的数组。输入数据框如下所示

    +------+---------------+--------------+      
    |EQ_KEY|Unix_start_time|Unix_stop_time| 
    +------+---------------+--------------+ 
    | 0    | 1349366011    | 1508753519   | 
    +------+---------------+--------------+

我为此使用的自定义UDAF是

// Custom Function modification

class Chronos extends UserDefinedAggregateFunction {
  import scala.collection.mutable.WrappedArray
  val bucketSize = 175321 

  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("Unix_start_time", LongType) :: StructField("Unix_stop_time", LongType) :: Nil)

  override def bufferSchema: StructType = StructType(
  StructField("kalam", ArrayType(ShortType,false)) :: Nil)

  override def dataType: DataType = ArrayType(ShortType,false)

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
  buffer(0) = new Array[Short](bucketSize)
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      val index_start:Int = ((input.getLong(0) - 946684800 +1)/3600).toInt

      val index_end:Int = ((input.getLong(1)- 946684800 +1)/3600).toInt

      val len_to = index_end - index_start + 1
      //val to_fill = new Array[Short](len_to)


      val arr_in = buffer.getAs[WrappedArray[Short]](0).toArray
      val to_fill = Array.fill[Short](len_to)(1)
      to_fill.copyToArray(arr_in,index_start,len_to)

      buffer(0) = arr_in   // TODO THIS TAKES WAYYYYY TOO LONG - it actually copies the entire array for every call to this method (which essentially updates only 1 cell)
    }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val arr1 = buffer1.getAs[WrappedArray[Short]](0).toArray
    val arr2 = buffer2.getAs[WrappedArray[Short]](0).toArray

    for(i <- arr1.indices){
      var updated_value:Short = 0
      if(arr1(i) == 1){updated_value = 1}else{if(arr2(i)==1){updated_value = 1}}
      arr1.update(i, updated_value)
    }
    buffer1(0) = arr1
  }

  override def evaluate(buffer: Row): Any = {
  buffer.getAs[WrappedArray[Short]](0)
  }
}

这里的想法是要计算每个KEY在随后的行中提供的非重叠时间段,但是Custom函数返回的数组中每个键的数组都为零。

其中,所需的输出将是一个固定大小的数组,其中1是,其中指定的时间段与每个键中的行一起传递。 请指出我的功能失败的地方。

0 个答案:

没有答案