如何在DataFrame中计算移动中位数?

时间:2017-05-19 04:36:50

标签: apache-spark apache-spark-sql

有没有办法计算spark DataFrame中属性的移动中位数

我希望可以使用窗口函数计算移动中位数(通过使用rowsBetween(0,10)定义窗口),但没有计算它的功能(类似于average或{{1 }})。

2 个答案:

答案 0 :(得分:4)

这是我扩展UserDefinedAggregateFunction以获得移动中位数的类。

class MyMedian extends org.apache.spark.sql.expressions.UserDefinedAggregateFunction {
  def inputSchema: org.apache.spark.sql.types.StructType =
    org.apache.spark.sql.types.StructType(org.apache.spark.sql.types.StructField("value", org.apache.spark.sql.types.DoubleType) :: Nil)

  def bufferSchema: org.apache.spark.sql.types.StructType = org.apache.spark.sql.types.StructType(
    org.apache.spark.sql.types.StructField("window_list", org.apache.spark.sql.types.ArrayType(org.apache.spark.sql.types.DoubleType, false)) :: Nil
  )
  def dataType: org.apache.spark.sql.types.DataType = org.apache.spark.sql.types.DoubleType
  def deterministic: Boolean = true
  def initialize(buffer: org.apache.spark.sql.expressions.MutableAggregationBuffer): Unit = {
    buffer(0) = new scala.collection.mutable.ArrayBuffer[Double]()
  }
  def update(buffer: org.apache.spark.sql.expressions.MutableAggregationBuffer,input: org.apache.spark.sql.Row): Unit = {
    var bufferVal=buffer.getAs[scala.collection.mutable.WrappedArray[Double]](0).toBuffer
    bufferVal+=input.getAs[Double](0)
    buffer(0) = bufferVal
  }
  def merge(buffer1: org.apache.spark.sql.expressions.MutableAggregationBuffer, buffer2: org.apache.spark.sql.Row): Unit = {
    buffer1(0) = buffer1.getAs[scala.collection.mutable.ArrayBuffer[Double]](0) ++ buffer2.getAs[scala.collection.mutable.ArrayBuffer[Double]](0)
  }
  def evaluate(buffer: org.apache.spark.sql.Row): Any = {
      var sortedWindow=buffer.getAs[scala.collection.mutable.WrappedArray[Double]](0).sorted.toBuffer
      var windowSize=sortedWindow.size
      if(windowSize%2==0){
          var index=windowSize/2
          (sortedWindow(index) + sortedWindow(index-1))/2
      }else{
          var index=(windowSize+1)/2 - 1
          sortedWindow(index)
      }
  }
}

使用上面的UDAF示例:

// Create an instance of UDAF MyMedian.
val mm = new MyMedian

var movingMedianDS = dataSet.withColumn("MovingMedian", mm(col("value")).over( Window.partitionBy("GroupId").rowsBetween(-10,10)) )

答案 1 :(得分:1)

你在这里几乎没有选择。

窗口功能

我认为ntile(2)(在一个行窗口上)会给你两个“段”,反过来你可以用来计算窗口的中位数。

引用scaladoc

  

ntile(n:Int)窗口函数:在有序窗口分区中返回ntile组id(从1到n)。例如,如果n为4,则第一季度的行将获得值1,第二季度将获得2,​​第三季度将获得3,最后一个季度将获得4。

     

这相当于SQL中的NTILE函数。

如果一个组中的行数大于另一个组中的行数,请从较大的组中选择最大的行。

如果组中的行数是偶数,请取每组中的最大值和最小值并计算中位数。

我发现在Calculating median using the NTILE function中很好地描述了它。

percent_rank窗口函数

我认为percent_rank也可能是计算行窗口中位数的选项。

引用scaladoc

  

percent_rank()窗口函数:返回窗口分区中行的相对等级(即百分位数)。

     

这是通过以下方式计算的:

     

(rank of row in its partition - 1) / (number of rows in the partition - 1)

     

这相当于SQL中的PERCENT_RANK函数。

用户定义的聚合函数(UDAF)

您可以编写用户定义的聚合函数(UDAF)来计算窗口的中位数。

UDAF扩展org.apache.spark.sql.expressions.UserDefinedAggregateFunction(引用scaladoc):

  

实现用户定义的聚合函数(UDAF)的基类。

幸运的是,在UserDefinedUntypedAggregation示例中有一个自定义UDAF的示例实现。