I'm trying to write my own Spark Scala UserDefinedAggregateFunction for a simple stock-analysis program, but I'm having trouble with the merge step. When I print the contents of the buffer after it has been populated in the update method, the values look correct. Unfortunately, the buffer argument passed to merge never carries any meaningful values; it looks as if it has only just been initialized.
The debug output prints the following line many times:
DEBUG: AggregateDeltaVolume.merge buffer(0)=0 buffer(1)=0.0
The buffer argument to merge holds nothing but the default values set in initialize.
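For context, the check described above is the update method with a debugging println at the end (shown here only as a sketch; the print statement is not kept in the code below), and there the buffer does show the running count and sum:

// update with a temporary debug print at the end; at this point the buffer
// contains the expected running values.
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  val min: Double = input.getAs[Double](0)
  val max: Double = input.getAs[Double](1)
  buffer(0) = buffer.getAs[Long](0) + 1L
  buffer(1) = buffer.getAs[Double](1) + (max - min) / min
  println("DEBUG: AggregateDeltaVolume.update buffer(0)=" + buffer(0) + " buffer(1)=" + buffer(1))
}

The full class and the way I call it are below.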
import java.time._
import java.time.format.DateTimeFormatter
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
class AggregateDeltaVolume extends UserDefinedAggregateFunction {

  // Input: the MinPrice and MaxPrice values for a single row.
  override def inputSchema: org.apache.spark.sql.types.StructType = StructType(
    Array(StructField("min", DoubleType), StructField("max", DoubleType))
  )

  // Intermediate state: the number of rows seen and the running sum of per-row deltas.
  override def bufferSchema: StructType = StructType(
    StructField("count", LongType) ::
    StructField("volumeDeltaSum", DoubleType) :: Nil
  )

  override def dataType: DataType = DoubleType

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0.0d
  }

  // Updates the given aggregation buffer with new input data from input.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val min: Double = input.getAs[Double](0)
    val max: Double = input.getAs[Double](1)
    val delta: Double = (max - min) / min
    buffer(0) = buffer.getAs[Long](0) + 1L
    buffer(1) = buffer.getAs[Double](1) + delta
  }

  // Merges two aggregation buffers and stores the updated values back into buffer.
  override def merge(buffer: MutableAggregationBuffer, input: Row): Unit = {
    println("DEBUG: AggregateDeltaVolume.merge buffer(0)=" + buffer(0) + " buffer(1)=" + buffer(1))
    buffer(0) = buffer.getAs[Long](0) + input.getAs[Long](0)
    buffer(1) = buffer.getAs[Double](1) + input.getAs[Double](1)
  }

  // Final result: the average delta across all rows in the group.
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(1) / buffer.getLong(0)
  }

  override def deterministic: Boolean = true
}
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
val adv = new AggregateDeltaVolume
sqlContext.udf.register("adv", adv)
spark.udf.register("adv", adv)
val df = spark.read.format("csv").option("header", "true").load("./data/2018-02-01/2018-02-01_BINS_XETR08.csv")
val dayView = df.groupBy("Mnemonic", "Date").agg(expr("adv(MinPrice, MaxPrice) as AvgDeltaVolume"))
dayView.show
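As a side note, the same aggregation can also be invoked without going through SQL registration, by applying the UDAF instance directly to the columns (a sketch, assuming the column names match the CSV header used above):

import org.apache.spark.sql.functions.col

// Equivalent query through the Dataset API: apply the UDAF instance to the price columns.
val dayViewDirect = df
  .groupBy("Mnemonic", "Date")
  .agg(adv(col("MinPrice"), col("MaxPrice")).as("AvgDeltaVolume"))
dayViewDirect.show()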