在窗口数据帧上计算四分位数

时间:2018-07-12 00:23:07

标签: scala apache-spark apache-spark-sql

我有一些数据,为了便于讨论,假设数据如下:

// Column names of the sample DataFrame, in tuple order.
val schema = Seq("id", "day", "value")

// Two ids with six consecutive days each; the value alternates between
// 1 (odd days) and 11 (even days).
val data = for {
  id  <- Seq(1, 2)
  day <- 1 to 6
} yield (id, day, if (day % 2 == 0) 11 else 1)

val df = sc.parallelize(data).toDF(schema: _*)

我想针对每个 ID,在按天滑动的窗口上计算四分位数。

// Per-id window ordered by day, covering the current day and the two days before it.
val w = Window.partitionBy("id").orderBy("day").rangeBetween(-2, 0)

// For every row: the raw values inside the window and their quartiles.
df.select(
  col("id"),
  col("day"),
  collect_list(col("value")).over(w),
  quartiles(col("value")).over(w).as("Quartiles")
)

当然,Spark 并没有内置的四分位数函数,因此我需要编写一个 UserDefinedAggregateFunction。以下是一个简单(但不可扩展)的解决方案(基于这个 CollectionFunction):

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Aggregate function that returns the quartiles `[Q1, Q2, Q3]` of a numeric column.
 *
 * The aggregation buffer keeps every value seen so far, sorted ascending and split
 * into a `lower` half (the first `n / 2` values) and an `upper` half (the rest).
 * Memory therefore grows linearly with the number of aggregated rows; this is a
 * simple implementation, not a scalable one.
 */
class QuartilesFunction extends UserDefinedAggregateFunction {

  /** A single numeric input column, read as a double. */
  def inputSchema: StructType =
    StructType(StructField("value", DoubleType, false) :: Nil)

  /** Two sorted arrays: the lower and the upper half of the values seen so far. */
  def bufferSchema: StructType = StructType(
    StructField("lower", ArrayType(DoubleType, true), true) ::
      StructField("upper", ArrayType(DoubleType, true), true) :: Nil)

  /** The result is an array of three (possibly missing) doubles. */
  override def dataType: DataType = ArrayType(DoubleType, true)

  def deterministic: Boolean = true

  /** Starts with two empty halves. */
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = IndexedSeq[Double]()
    buffer(1) = IndexedSeq[Double]()
  }

  /** Reads buffer field `i` as a sequence of doubles, treating a null field as empty. */
  private def half(row: Row, i: Int): IndexedSeq[Double] =
    if (row.isNullAt(i)) IndexedSeq.empty[Double] else row.getSeq[Double](i).toIndexedSeq

  /**
   * Sorts all values of `lower` and `upper` together and splits them in the middle.
   *
   * @return `(lower, upper)`, both sorted ascending, where `lower` holds the first
   *         `n / 2` values; when `n` is odd the middle value is the head of `upper`
   */
  def rebalance(
      lower: IndexedSeq[Double],
      upper: IndexedSeq[Double]): (IndexedSeq[Double], IndexedSeq[Double]) = {
    // Sorting here (instead of trusting the callers) keeps the buffer invariant
    // intact no matter how the two inputs were assembled.
    val all = (lower ++ upper).sortWith(_ < _)
    all.splitAt(all.length / 2)
  }

  /**
   * Median of an already sorted sequence.
   *
   * @param x values sorted ascending
   * @return `None` when `x` is empty, otherwise the middle value (the mean of the
   *         two middle values when the length is even)
   */
  def sorted_median(x: IndexedSeq[Double]): Option[Double] =
    if (x.isEmpty) {
      None
    } else {
      val mid = x.length / 2
      Some(if (x.length % 2 == 0) (x(mid - 1) + x(mid)) / 2.0 else x(mid))
    }

  /** Adds one input value to the buffer; null inputs are skipped. */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val value = input.getAs[Double](0)
      val (lower, upper) = rebalance(half(buffer, 0) :+ value, half(buffer, 1))
      buffer(0) = lower
      buffer(1) = upper
    }
  }

  /** Merges the values held by `buffer2` into `buffer1`. */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // The halves of two independent buffers interleave arbitrarily, so everything
    // is pooled and re-sorted rather than concatenated half by half.
    val (lower, upper) = rebalance(
      half(buffer1, 0) ++ half(buffer1, 1),
      half(buffer2, 0) ++ half(buffer2, 1))
    buffer1(0) = lower
    buffer1(1) = upper
  }

  /**
   * Computes the quartiles from the buffer.
   *
   * @return `Array(Q1, Q2, Q3)`: Q1 is the median of the lower half, Q2 the median
   *         of all values and Q3 the median of the upper half. An entry is `None`
   *         when the corresponding sequence is empty.
   */
  def evaluate(buffer: Row): Array[Option[Double]] = {
    // NOTE(review): `None` entries presumably surface as null array elements in the
    // result column - confirm against the Spark version in use.
    val lower = half(buffer, 0)
    val upper = half(buffer, 1)
    // lower ++ upper is the full sorted sequence because of the rebalance invariant.
    Array(sorted_median(lower), sorted_median(lower ++ upper), sorted_median(upper))
  }
}

但是,执行以下操作会产生错误:

// Instantiate the aggregate and print the input rows for reference.
val quartiles = new QuartilesFunction
df.select(col("*")).show()

// Per-id window ordered by day, covering the current day and the two days before it.
val w = org.apache.spark.sql.expressions.Window.partitionBy("id").orderBy("day").rangeBetween(-2, 0)

// Raw window contents next to the quartiles computed by the aggregate.
val x = df.select(
  col("id"),
  col("day"),
  collect_list(col("value")).over(w),
  quartiles(col("value")).over(w).as("Quantiles")
)
x.show()

错误是:

org.apache.spark.SparkException: Task not serializable

出问题的函数似乎是 sorted_median。如果我将代码替换为:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Diagnostic variant of the quartile aggregate: `sorted_median` is commented out
 * and `evaluate` returns fixed placeholder values, so this class does not compute
 * real quartiles. The buffer handling (`initialize`, `update`, `merge`) is kept.
 */
class QuartilesFunction extends UserDefinedAggregateFunction {
    // A single non-nullable double input column named "value".
    def inputSchema: StructType =
        StructType(StructField("value", DoubleType, false) :: Nil)

    // The buffer holds two arrays of doubles: "lower" (field 0) and "upper" (field 1).
    def bufferSchema: StructType = StructType(StructField("lower", ArrayType(DoubleType, true), true) :: StructField("upper", ArrayType(DoubleType, true), true) :: Nil)

    // The result column is an array of doubles.
    override def dataType: DataType = ArrayType(DoubleType, true)

    def deterministic: Boolean = true

    // Both halves start out empty.
    def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = IndexedSeq[Double]()
        buffer(1) = IndexedSeq[Double]()
    }

  // Concatenates lower and upper and splits the result at half the combined length
  // (integer division), so the first part is the shorter one when the total is odd.
  // The concatenation is not re-sorted here.
  def rebalance(lower : IndexedSeq[Double], upper : IndexedSeq[Double]) = {
    (lower++upper).splitAt((lower.length+upper.length)/2)
  }
  // sorted_median is disabled in this variant (see the placeholders in evaluate).
/*
  def sorted_median(x : IndexedSeq[Double]) : Option[Double] = { 
    if(x.length == 0) {
      None
    }
    val N = x.length
    val (lower, upper) = x.splitAt(N/2)
    Some(
            if(N%2==0) {
                (lower.last+upper.head)/2.0
            } else {
                upper.head
            }
    )
  }
*/
  // Adds one input value to the buffer: the first value goes into lower; afterwards
  // a value >= lower.last is added to upper, any other value to lower, and the
  // touched half is re-sorted. Finally the two halves are rebalanced.
  def update(buffer : MutableAggregationBuffer, input : Row) : Unit = {
    val lower = buffer(0).asInstanceOf[IndexedSeq[Double]]
    val upper = buffer(1).asInstanceOf[IndexedSeq[Double]]
    val value = input.getAs[Double](0)
    if(lower.length == 0) {
      buffer(0) = Array(value)
    } else {
      if(value >= lower.last) {
        buffer(1) = (value +: upper).sortWith(_<_)
      } else {
        buffer(0) = (lower :+ value).sortWith(_<_)
      }
    }
    val (result0,result1) = rebalance(buffer(0).asInstanceOf[IndexedSeq[Double]],buffer(1).asInstanceOf[IndexedSeq[Double]])
    buffer(0) = result0
    buffer(1) = result1
  }

  // Merges buffer2 into buffer1 by appending lower to lower and upper to upper,
  // then rebalancing.
  // NOTE(review): nothing is re-sorted after the concatenation, so the merged
  // halves are not guaranteed to be in ascending order.
  def merge(buffer1 : MutableAggregationBuffer, buffer2 : Row) : Unit = {
    buffer1(0) = buffer1(0).asInstanceOf[IndexedSeq[Double]] ++ buffer2(0).asInstanceOf[IndexedSeq[Double]]
    buffer1(1) = buffer1(1).asInstanceOf[IndexedSeq[Double]] ++ buffer2(1).asInstanceOf[IndexedSeq[Double]]
    val (result0,result1) = rebalance(buffer1(0).asInstanceOf[IndexedSeq[Double]],buffer1(1).asInstanceOf[IndexedSeq[Double]])
    buffer1(0) = result0
    buffer1(1) = result1
  }

    // Returns the fixed placeholders 1.0, 2.0 and 3.0 in place of Q1, Q2 and Q3;
    // the real computations are kept as trailing comments. lower and upper are
    // still read from the buffer (null fields become empty) but are unused here.
    def evaluate(buffer: Row): Array[Option[Double]] = {
        val lower = 
        if (buffer(0) == null) {
                IndexedSeq[Double]()
        } else {
                buffer(0).asInstanceOf[IndexedSeq[Double]]
        }
        val upper = 
        if (buffer(1) == null) {
                IndexedSeq[Double]()
        } else {
                buffer(1).asInstanceOf[IndexedSeq[Double]]
        }
        val Q1 = Some(1.0)//sorted_median(lower)
        val Q2 = Some(2.0)//if(upper.length==0) { None } else { Some(upper.head) }
        val Q3 = Some(3.0)//sorted_median(upper)
        Array(Q1,Q2,Q3)
    }
}

然后一切正常,只是(显然)它不再计算四分位数。我不明白这个错误,而且堆栈跟踪的其余部分也没有提供更多线索。有人可以帮助我了解问题出在哪里,和/或如何计算这些四分位数吗?

1 个答案:

答案 0 :(得分:0)

如果您有 Hive 上下文(HiveContext,或者启用了 Hive 支持),则可以按以下方式使用 percentile UDAF:

// Quartiles through Hive's `percentile` aggregate, evaluated over the window `w`.
val dfQuartiles = {
  // Output column name -> percentile passed to the aggregate.
  val percentileColumns = Seq("Q1" -> 0.25, "Q2" -> 0.50, "Q3" -> 0.75, "Q4" -> 1.0).map {
    case (name, p) => callUDF("percentile", col("value"), lit(p)).over(w).as(name)
  }
  val baseColumns = Seq(
    col("id"),
    col("day"),
    collect_list(col("value")).over(w).as("values")
  )
  df.select(baseColumns ++ percentileColumns: _*)
}

或者,您也可以用一个 UDF 直接从 values 列计算四分位数(反正您已经有了这个数组):

// UDF that picks a percentile out of an array of ints.
//   xs         - the values; they do not need to be sorted
//   percentile - expected to lie in [0, 1]
// The result is the element at index floor((n - 1) * percentile) of the sorted
// values (no interpolation). A null or empty array yields no value, which shows
// up as null in the result column, instead of failing the task with an
// out-of-bounds error.
val calcPercentile = udf((xs: Seq[Int], percentile: Double) =>
  Option(xs).filter(_.nonEmpty).map { values =>
    val ordered = values.sorted
    ordered(((ordered.size - 1) * percentile).toInt)
  }
)

// Collect the window contents once, then derive each quartile from that array.
val dfQuartiles = {
  val withValues = df.select(
    col("id"),
    col("day"),
    collect_list(col("value")).over(w).as("values")
  )
  // Output column name -> percentile, appended in this order.
  Seq("Q1" -> 0.25, "Q2" -> 0.50, "Q3" -> 0.75, "Q4" -> 1.00).foldLeft(withValues) {
    case (acc, (name, p)) => acc.withColumn(name, calcPercentile(col("values"), lit(p)))
  }
}