I'm using the rowsBetween window function to compute a moving median, as follows:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.col
val mm = new MovingMedian
val rawdataFiltered = rawdata.withColumn("movingmedian", mm(col("value")).over(Window.partitionBy("raw_data_field_id").orderBy("date_time_epoch").rowsBetween(-50, 50)))
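This gives a window of the 50 rows before and the 50 rows after the current row. However, I need to exclude the rows at the beginning and end (of each raw_data_field_id partition) that do not have a full 50 rows before or after them.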
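To pin down the intended result, here is a minimal Spark-free Scala sketch of it, assuming the values of one raw_data_field_id partition are already sorted by date_time_epoch. MovingMedianSketch, median and fullWindowMedians are hypothetical names; k plays the role of the 50 in rowsBetween(-50, 50), and the median rule mirrors the evaluate method of the UDAF below.

```scala
// Hypothetical, Spark-free model of the wanted output for ONE sorted partition.
object MovingMedianSketch {
  // Middle value for an odd count, mean of the two middle values for an even count.
  def median(xs: Seq[Double]): Double = {
    require(xs.nonEmpty, "median of an empty window is undefined")
    val s = xs.sorted
    val n = s.size
    if (n % 2 == 0) (s(n / 2 - 1) + s(n / 2)) / 2 else s(n / 2)
  }

  // Median over rows [i - k, i + k], produced only for rows that have k rows
  // on both sides; the first k and the last k rows get no output.
  def fullWindowMedians(values: Seq[Double], k: Int): List[Double] = {
    val width = 2 * k + 1
    // sliding() yields one short group when values.size < width, so guard it.
    if (values.size < width) Nil
    else values.sliding(width).map(median).toList
  }

  def main(args: Array[String]): Unit = {
    // Windows [5,1,4], [1,4,2], [4,2,3] -> medians 4, 2, 3
    println(fullWindowMedians(Seq(5.0, 1.0, 4.0, 2.0, 3.0), 1)) // List(4.0, 2.0, 3.0)
  }
}
```

With k = 50, a partition of N rows yields N - 100 medians (none if N < 101), which is exactly the set of rows the question wants to keep.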
Reference code:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Exact moving median: collects every value in the frame, sorts them in evaluate().
class MovingMedian extends UserDefinedAggregateFunction {
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)

  // The buffer holds all values seen so far in the frame.
  def bufferSchema: StructType =
    StructType(StructField("window_list", ArrayType(DoubleType, containsNull = false)) :: Nil)

  def dataType: DataType = DoubleType
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq.empty[Double]
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Skip null inputs; getAs[Double] would silently turn them into 0.0.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getSeq[Double](0) :+ input.getDouble(0)
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // Array buffers come back as a Seq (not an ArrayBuffer), so read them with getSeq.
    buffer1(0) = buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0)
  }

  def evaluate(buffer: Row): Any = {
    val sorted = buffer.getSeq[Double](0).sorted
    val n = sorted.size
    if (n == 0) null                                             // empty frame: no median
    else if (n % 2 == 0) (sorted(n / 2 - 1) + sorted(n / 2)) / 2 // mean of the two middle values
    else sorted(n / 2)                                           // middle value
  }
}
Answer (score: 0)
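You can filter on the window size: count the rows in each frame over the same window and keep only the rows whose frame is full (3 rows for rowsBetween(-1, 1)):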
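import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{collect_list, count}
import spark.implicits._ // toDF and $"..."; assumes a SparkSession named spark

val df = Seq(1, 2, 3, 4, 5).toDF("foo")
val win = Window.orderBy("foo").rowsBetween(-1, 1)
df.select($"foo",
    collect_list($"foo") over win as "agg",
    count($"*") over win as "cnt")
  .filter($"cnt" === 3)
  .show()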
Output:
+---+---------+---+
|foo| agg|cnt|
+---+---------+---+
| 2|[1, 2, 3]| 3|
| 3|[2, 3, 4]| 3|
| 4|[3, 4, 5]| 3|
+---+---------+---+
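Applied to the question, the same idea means counting the rows in the 101-row frame and keeping only rows where that count equals 2 * 50 + 1 = 101. Below is a sketch assuming Spark 2.x or 3.x and the question's column names; FullWindowFilter and withFullWindowAgg are hypothetical names, and the aggregate (for example mm(col("value")) from the question) is passed in as a Column so the helper stays self-contained.

```scala
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, count, lit}

// Hypothetical helper: evaluates `agg` over a window of `halfWidth` rows on each
// side and keeps only rows whose frame was not clipped by a partition edge.
object FullWindowFilter {
  def withFullWindowAgg(df: DataFrame, agg: Column, outCol: String, halfWidth: Int): DataFrame = {
    val w = Window
      .partitionBy("raw_data_field_id")
      .orderBy("date_time_epoch")
      .rowsBetween(-halfWidth, halfWidth)
    df.withColumn(outCol, agg.over(w))
      // Same WindowSpec, so the count sees exactly the frame the aggregate saw.
      .withColumn("window_rows", count(lit(1)).over(w))
      // Edge rows have fewer than 2 * halfWidth + 1 rows in their frame.
      .filter(col("window_rows") === 2L * halfWidth + 1)
      .drop("window_rows")
  }
}
```

With the question's UDAF this would be called as FullWindowFilter.withFullWindowAgg(rawdata, mm(col("value")), "movingmedian", 50). Because the filter runs after both window columns are computed, the dropped edge rows still contribute their values to their neighbours' medians. An equivalent alternative is to compare row_number() over the ordered window with the partition's total row count and keep rows numbered from 51 through total - 50.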