有没有办法计算spark DataFrame中属性的移动中位数?
我希望可以使用窗口函数计算移动中位数(通过使用rowsBetween(0,10)
定义窗口),但没有计算它的功能(类似于average
或{{1 }})。
答案 0 :(得分:4)
这是我扩展UserDefinedAggregateFunction以获得移动中位数的类。
class MyMedian extends org.apache.spark.sql.expressions.UserDefinedAggregateFunction {
def inputSchema: org.apache.spark.sql.types.StructType =
org.apache.spark.sql.types.StructType(org.apache.spark.sql.types.StructField("value", org.apache.spark.sql.types.DoubleType) :: Nil)
def bufferSchema: org.apache.spark.sql.types.StructType = org.apache.spark.sql.types.StructType(
org.apache.spark.sql.types.StructField("window_list", org.apache.spark.sql.types.ArrayType(org.apache.spark.sql.types.DoubleType, false)) :: Nil
)
def dataType: org.apache.spark.sql.types.DataType = org.apache.spark.sql.types.DoubleType
def deterministic: Boolean = true
def initialize(buffer: org.apache.spark.sql.expressions.MutableAggregationBuffer): Unit = {
buffer(0) = new scala.collection.mutable.ArrayBuffer[Double]()
}
def update(buffer: org.apache.spark.sql.expressions.MutableAggregationBuffer,input: org.apache.spark.sql.Row): Unit = {
var bufferVal=buffer.getAs[scala.collection.mutable.WrappedArray[Double]](0).toBuffer
bufferVal+=input.getAs[Double](0)
buffer(0) = bufferVal
}
def merge(buffer1: org.apache.spark.sql.expressions.MutableAggregationBuffer, buffer2: org.apache.spark.sql.Row): Unit = {
buffer1(0) = buffer1.getAs[scala.collection.mutable.ArrayBuffer[Double]](0) ++ buffer2.getAs[scala.collection.mutable.ArrayBuffer[Double]](0)
}
def evaluate(buffer: org.apache.spark.sql.Row): Any = {
var sortedWindow=buffer.getAs[scala.collection.mutable.WrappedArray[Double]](0).sorted.toBuffer
var windowSize=sortedWindow.size
if(windowSize%2==0){
var index=windowSize/2
(sortedWindow(index) + sortedWindow(index-1))/2
}else{
var index=(windowSize+1)/2 - 1
sortedWindow(index)
}
}
}
使用上面的UDAF示例:
// Create an instance of UDAF MyMedian.
val mm = new MyMedian
var movingMedianDS = dataSet.withColumn("MovingMedian", mm(col("value")).over( Window.partitionBy("GroupId").rowsBetween(-10,10)) )
答案 1 :(得分:1)
我想你在这里几乎没有选择。
我认为ntile(2)
(在一个行窗口上)会给你两个“段”,反过来你可以用来计算窗口的中位数。
引用scaladoc:
ntile(n:Int)窗口函数:在有序窗口分区中返回ntile组id(从1到n)。例如,如果n为4,则第一季度的行将获得值1,第二季度将获得2,第三季度将获得3,最后一个季度将获得4。
这相当于SQL中的NTILE函数。
如果一个组中的行数大于另一个组中的行数,请从较大的组中选择最大的行。
如果组中的行数是偶数,请取每组中的最大值和最小值并计算中位数。
我发现在Calculating median using the NTILE function中很好地描述了它。
我认为percent_rank
也可能是计算行窗口中位数的选项。
引用scaladoc:
percent_rank()窗口函数:返回窗口分区中行的相对等级(即百分位数)。
这是通过以下方式计算的:
(rank of row in its partition - 1) / (number of rows in the partition - 1)
这相当于SQL中的PERCENT_RANK函数。
您可以编写用户定义的聚合函数(UDAF)来计算窗口的中位数。
UDAF扩展org.apache.spark.sql.expressions.UserDefinedAggregateFunction(引用scaladoc):
实现用户定义的聚合函数(UDAF)的基类。
幸运的是,在UserDefinedUntypedAggregation示例中有一个自定义UDAF的示例实现。