I have some data which, for the purposes of discussion, is given by:
val schema = Seq("id", "day", "value")
val data = Seq(
  (1, 1, 1),
  (1, 2, 11),
  (1, 3, 1),
  (1, 4, 11),
  (1, 5, 1),
  (1, 6, 11),
  (2, 1, 1),
  (2, 2, 11),
  (2, 3, 1),
  (2, 4, 11),
  (2, 5, 1),
  (2, 6, 11)
)
val df = sc.parallelize(data).toDF(schema: _*)
I would like to compute quartiles for each id over a moving window of days (rangeBetween(-2, 0) covers the current day and the two preceding days):
val w = Window.partitionBy("id").orderBy("day").rangeBetween(-2, 0)
df.select(
  col("id"),
  col("day"),
  collect_list(col("value")).over(w),
  quartiles(col("value")).over(w).as("Quartiles")
)
Of course, no quartiles function is provided, so I need to write a UserDefinedAggregateFunction. Here is a simple (though non-scalable) solution, based on this CollectionFunction:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
class QuartilesFunction extends UserDefinedAggregateFunction {
  def inputSchema: StructType =
    StructType(StructField("value", DoubleType, false) :: Nil)

  def bufferSchema: StructType = StructType(
    StructField("lower", ArrayType(DoubleType, true), true) ::
    StructField("upper", ArrayType(DoubleType, true), true) :: Nil
  )

  override def dataType: DataType = ArrayType(DoubleType, true)

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = IndexedSeq[Double]()
    buffer(1) = IndexedSeq[Double]()
  }
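  // lower ++ upper stays sorted (every element of lower is <= every element of upper),
  // so re-splitting at the midpoint keeps the two halves balanced in length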
  def rebalance(lower: IndexedSeq[Double], upper: IndexedSeq[Double]) = {
    (lower ++ upper).splitAt((lower.length + upper.length) / 2)
  }
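  // median of an already-sorted sequence,
  // e.g. sorted_median(IndexedSeq(1.0, 2.0, 3.0, 4.0)) == Some(2.5)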
  def sorted_median(x: IndexedSeq[Double]): Option[Double] = {
    if (x.isEmpty) {
      None
    } else {
      val N = x.length
      val (lower, upper) = x.splitAt(N / 2)
      Some(
        if (N % 2 == 0) {
          (lower.last + upper.head) / 2.0
        } else {
          upper.head
        }
      )
    }
  }
  // this is how to update the buffer given an input
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val lower = buffer(0).asInstanceOf[IndexedSeq[Double]]
    val upper = buffer(1).asInstanceOf[IndexedSeq[Double]]
    val value = input.getAs[Double](0)
    if (lower.length == 0) {
      buffer(0) = Array(value)
    } else {
      if (value >= lower.last) {
        buffer(1) = (value +: upper).sortWith(_ < _)
      } else {
        buffer(0) = (lower :+ value).sortWith(_ < _)
      }
    }
    val (result0, result1) = rebalance(
      buffer(0).asInstanceOf[IndexedSeq[Double]],
      buffer(1).asInstanceOf[IndexedSeq[Double]]
    )
    buffer(0) = result0
    buffer(1) = result1
  }
  // this is how to merge two objects with the buffer schema type
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1(0).asInstanceOf[IndexedSeq[Double]] ++ buffer2(0).asInstanceOf[IndexedSeq[Double]]
    buffer1(1) = buffer1(1).asInstanceOf[IndexedSeq[Double]] ++ buffer2(1).asInstanceOf[IndexedSeq[Double]]
    val (result0, result1) = rebalance(
      buffer1(0).asInstanceOf[IndexedSeq[Double]],
      buffer1(1).asInstanceOf[IndexedSeq[Double]]
    )
    buffer1(0) = result0
    buffer1(1) = result1
  }
  def evaluate(buffer: Row): Array[Option[Double]] = {
    val lower =
      if (buffer(0) == null) {
        IndexedSeq[Double]()
      } else {
        buffer(0).asInstanceOf[IndexedSeq[Double]]
      }
    val upper =
      if (buffer(1) == null) {
        IndexedSeq[Double]()
      } else {
        buffer(1).asInstanceOf[IndexedSeq[Double]]
      }
    val Q1 = sorted_median(lower)
    val Q2 = if (upper.isEmpty) { None } else { Some(upper.head) }
    val Q3 = sorted_median(upper)
    Array(Q1, Q2, Q3)
  }
}
However, running the following produces an error:
val quartiles = new QuartilesFunction
df.select('*).show
val w = org.apache.spark.sql.expressions.Window.partitionBy("id").orderBy("day").rangeBetween(-2, 0)
val x = df.select(
  col("id"),
  col("day"),
  collect_list(col("value")).over(w),
  quartiles(col("value")).over(w).as("Quantiles")
)
x.show
The error is:
org.apache.spark.SparkException: Task not serializable
The offending function seems to be sorted_median. If I replace the code with the same class, but with sorted_median commented out and the quartile computation in evaluate replaced by constants:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class QuartilesFunction extends UserDefinedAggregateFunction {
  // inputSchema, bufferSchema, dataType, deterministic, initialize,
  // rebalance, update and merge are identical to the version above

  /*
  def sorted_median(x: IndexedSeq[Double]): Option[Double] = { ... }
  */

  def evaluate(buffer: Row): Array[Option[Double]] = {
    val lower =
      if (buffer(0) == null) {
        IndexedSeq[Double]()
      } else {
        buffer(0).asInstanceOf[IndexedSeq[Double]]
      }
    val upper =
      if (buffer(1) == null) {
        IndexedSeq[Double]()
      } else {
        buffer(1).asInstanceOf[IndexedSeq[Double]]
      }
    val Q1 = Some(1.0) // sorted_median(lower)
    val Q2 = Some(2.0) // if (upper.isEmpty) { None } else { Some(upper.head) }
    val Q3 = Some(3.0) // sorted_median(upper)
    Array(Q1, Q2, Q3)
  }
}
then everything works, except that it doesn't compute the quartiles (obviously). I don't understand the error, and the rest of the stack trace is no more illuminating. Can someone help me understand what the problem is, and/or how to compute these quartiles?
Answer 0 (score: 0)
If you have a Hive context (or hiveSupportEnabled), you can use the percentile UDAF as follows, with w being the window defined in the question:
val dfQuartiles = df.select(
  col("id"),
  col("day"),
  collect_list(col("value")).over(w).as("values"),
  callUDF("percentile", col("value"), lit(0.25)).over(w).as("Q1"),
  callUDF("percentile", col("value"), lit(0.50)).over(w).as("Q2"),
  callUDF("percentile", col("value"), lit(0.75)).over(w).as("Q3"),
  callUDF("percentile", col("value"), lit(1.0)).over(w).as("Q4")
)
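For large windows, the exact percentile UDAF can get expensive, since it keeps every value of the group in memory. As a minimal sketch, assuming the same window w and Hive support, Hive's approximate counterpart percentile_approx can be called the same way (it expects a double column, hence the cast):
val dfApproxQuartiles = df.select(
  col("id"),
  col("day"),
  callUDF("percentile_approx", col("value").cast("double"), lit(0.25)).over(w).as("Q1"),
  callUDF("percentile_approx", col("value").cast("double"), lit(0.50)).over(w).as("Q2"),
  callUDF("percentile_approx", col("value").cast("double"), lit(0.75)).over(w).as("Q3")
)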
Alternatively, you can use a UDF to compute the quartiles from values (since you have this array anyway):
val calcPercentile = udf((xs: Seq[Int], percentile: Double) => {
  val ss = xs.sorted
  val index = ((ss.size - 1) * percentile).toInt
  ss(index) // nearest-rank percentile
})
val dfQuartiles = df
  .select(
    col("id"),
    col("day"),
    collect_list(col("value")).over(w).as("values")
  )
  .withColumn("Q1", calcPercentile($"values", lit(0.25)))
  .withColumn("Q2", calcPercentile($"values", lit(0.50)))
  .withColumn("Q3", calcPercentile($"values", lit(0.75)))
  .withColumn("Q4", calcPercentile($"values", lit(1.00)))