UserDefinedAggregateFunction merge() buffer is never populated

Asked: 2018-05-15 19:11:59

Tags: scala apache-spark user-defined-functions

I am trying to write my own Spark Scala UserDefinedAggregateFunction for a simple stock-analysis program, but I am having trouble with the merge function. When I print the contents of the buffer after populating it in the update method, the values appear to be filled in. Unfortunately, the buffer argument passed to the merge function never carries any meaningful values; the data looks as if it had only just been initialized.

The debug output prints the following line many times:

DEBUG:  AggregateDeltaVolume.merge buffer(0)=0 buffer(1)=0.0

The buffer argument to the merge function contains nothing beyond the default values.
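In Spark's merge(buffer, input) contract, the second argument carries the other partial aggregate that is being folded into buffer, so buffer can legitimately still hold its initialize() values the first time merge is called for a group. A variant of the debug line that also prints input (my addition, not part of the program below) would show whether the partial sums arrive on that side:

// Hypothetical replacement for the DEBUG println inside merge():
// print both the target buffer and the incoming partial aggregate.
println("DEBUG:  AggregateDeltaVolume.merge buffer=(" + buffer(0) + ", " + buffer(1) +
        ") input=(" + input(0) + ", " + input(1) + ")")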


import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

class AggregateDeltaVolume extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(
    Array(StructField("min", DoubleType), StructField("max", DoubleType))
  )

  override def bufferSchema: StructType = StructType(
    StructField("count", LongType) ::
    StructField("volumeDeltaSum", DoubleType) :: Nil
  )

  override def dataType: DataType = DoubleType

  // Initializes the aggregation buffer: count = 0, volumeDeltaSum = 0.0.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0.0d
  }

  // Updates the given aggregation buffer `buffer` with new input data from `input`.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val min: Double = input.getAs[Double](0)
    val max: Double = input.getAs[Double](1)
    val delta: Double = (max - min) / min  // relative spread between max and min price
    buffer(0) = buffer.getAs[Long](0) + 1L
    buffer(1) = buffer.getAs[Double](1) + delta
  }

  // Merges two aggregation buffers and stores the updated values back into `buffer`.
  override def merge(buffer: MutableAggregationBuffer, input: Row): Unit = {
    println("DEBUG:  AggregateDeltaVolume.merge buffer(0)=" + buffer(0) + " buffer(1)=" + buffer(1))
    buffer(0) = buffer.getAs[Long](0) + input.getAs[Long](0)
    buffer(1) = buffer.getAs[Double](1) + input.getAs[Double](1)
  }

  // Computes the final result for a group: the mean delta, volumeDeltaSum / count.
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(1) / buffer.getLong(0)
  }

  override def deterministic: Boolean = true
}

val adv = new AggregateDeltaVolume

// SparkSession and SQLContext share one UDF registry, so a single registration suffices.
spark.udf.register("adv", adv)

val df = spark.read.format("csv").option("header", "true").load("./data/2018-02-01/2018-02-01_BINS_XETR08.csv")

val dayView = df.groupBy("Mnemonic", "Date").agg(expr("adv(MinPrice, MaxPrice) as AvgDeltaVolume"))
dayView.show
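For readers without the Deutsche Börse CSV, the same pipeline can be exercised on a handful of invented rows (a minimal sketch, assuming a spark-shell session where spark and its implicits are available); repartitioning forces several partial aggregates per group so that merge() actually runs:

import spark.implicits._
import org.apache.spark.sql.functions.expr

// Hypothetical in-memory stand-in for the CSV: (Mnemonic, Date, MinPrice, MaxPrice).
val sample = Seq(
  ("SAP", "2018-02-01", 100.0, 102.0),
  ("SAP", "2018-02-01", 101.0, 103.0),
  ("BMW", "2018-02-01", 90.0, 90.9)
).toDF("Mnemonic", "Date", "MinPrice", "MaxPrice")
  .repartition(4)  // several partitions => several partial buffers => merge() is exercised

spark.udf.register("adv", new AggregateDeltaVolume)
sample
  .groupBy("Mnemonic", "Date")
  .agg(expr("adv(MinPrice, MaxPrice) as AvgDeltaVolume"))
  .show()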

0 Answers:

No answers yet.