I want to sum the values while the updateType column is 'relative', and restart the sum whenever it is 'absolute'.
Here is how I define my DataFrame:
// updateType is a plain String so that the String match in the UDF below applies
val df = sc.parallelize(Seq(
  (1, "2018-02-21", "relative", 3.00),
  (1, "2018-02-22", "relative", 4.00),
  (1, "2018-02-23", "absolute", 5.00),
  (1, "2018-02-24", "relative", 6.00),
  (1, "2018-02-26", "relative", 8.00)
)).toDF("id", "date", "updateType", "value")
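I defined a UDF to decide when to sum and when not to. I want to order by date and then, as appropriate, either add the value to the running total or set the total to the absolute value: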
val computeValue = udf((previous: java.math.BigDecimal, value: java.math.BigDecimal, updateType: String) => {
  updateType match {
    case "absolute" => value
    case "relative" => previous.add(value)
    case _ => previous
  }
})
val w = Window
  .partitionBy($"id")
  .orderBy($"date")
val result = df.select(
  $"id",
  $"date",
  computeValue(
    lag($"value", 1, 0).over(w),
    $"value",
    $"updateType"
  ).alias("sumValue")
)
This actually returns:
+---+----------+---------+
| id|      date| sumValue|
+---+----------+---------+
|  1|2018-02-21|    3.000|
|  1|2018-02-22|    7.000|
|  1|2018-02-23|     5.00|
|  1|2018-02-24|    11.00|
|  1|2018-02-26|    14.00|
+---+----------+---------+
What I am looking for is:
+---+----------+---------+
| id|      date| sumValue|
+---+----------+---------+
|  1|2018-02-21|    3.000|
|  1|2018-02-22|    7.000|
|  1|2018-02-23|     5.00|
|  1|2018-02-24|    11.00|
|  1|2018-02-26|    19.00|
+---+----------+---------+
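To make the gap between the two tables concrete, here is a minimal plain-Scala sketch (no Spark involved) of both computations. The names `Update`, `lagBased` and `runningWithReset` are hypothetical and exist only for this illustration. It assumes the rows are already sorted by date. The lag-based version only ever sees the previous row's raw value, while the desired version carries the running total forward.

```scala
// Hypothetical helper type for illustration; not part of the original code.
final case class Update(date: String, updateType: String, value: Double)

object RunningSumSketch {
  val rows: Seq[Update] = Seq(
    Update("2018-02-21", "relative", 3.0),
    Update("2018-02-22", "relative", 4.0),
    Update("2018-02-23", "absolute", 5.0),
    Update("2018-02-24", "relative", 6.0),
    Update("2018-02-26", "relative", 8.0)
  )

  // Mirrors the lag-based UDF: "previous" is the previous row's raw value
  // (0 for the first row), not the total accumulated so far.
  def lagBased(rows: Seq[Update]): Seq[Double] = {
    val previousValues = 0.0 +: rows.map(_.value).dropRight(1)
    rows.zip(previousValues).map { case (row, previous) =>
      row.updateType match {
        case "absolute" => row.value
        case "relative" => previous + row.value
        case _          => previous
      }
    }
  }

  // The desired behaviour: carry the running total from row to row and
  // reset it whenever an "absolute" row is encountered.
  def runningWithReset(rows: Seq[Update]): Seq[Double] =
    rows.scanLeft(0.0) { (total, row) =>
      row.updateType match {
        case "absolute" => row.value
        case "relative" => total + row.value
        case _          => total
      }
    }.tail // drop the initial 0.0 seed

  def main(args: Array[String]): Unit = {
    println(lagBased(rows).mkString(", "))
    println(runningWithReset(rows).mkString(", "))
  }
}
```

The last row is where the two diverge: the lag-based version adds 8 to the previous raw value 6, whereas the running version adds 8 to the accumulated total 11.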
Answer 0 (score: 1)
The answer is to use a UDAF (user-defined aggregate function) for this kind of operation.
// Init aggregation function to compute values
val computeValue = new ComputeValue
val w = Window
  .partitionBy($"id")
  .orderBy($"date")
val result = df.select(
  $"id",
  $"date",
  computeValue(
    $"value",
    $"updateType"
  ).over(w).alias("sumValue")
)
where the ComputeValue UDAF is defined as:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class ComputeValue extends UserDefinedAggregateFunction {
  // Each input row is of type value: Double - update_type: String
  override def inputSchema: StructType =
    StructType(
      StructField("value", DoubleType) ::
      StructField("update_type", StringType) :: Nil)
  // Buffer with a single column holding the running total
  override def bufferSchema: StructType = StructType(
    StructField("value", DoubleType) :: Nil
  )
  override def dataType: DataType = DoubleType
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0.0
  // This is how to update the buffer given an input row.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = computeValue(buffer, input)
  }
  // merge receives two partial buffers. buffer2 holds only the running total and no
  // update type, so the order-dependent logic cannot be applied to it. This UDAF is
  // meant to be used over an ordered window, where rows are expected to arrive
  // through update, so merge fails loudly instead of reading a missing field.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    throw new UnsupportedOperationException("ComputeValue is intended for ordered windows only")
  // Get the final value from the buffer.
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(0)
  }
  // Running total after applying one input row (value at index 0, update type at index 1)
  private def computeValue(buffer: MutableAggregationBuffer, row: Row): Double = {
    val updateType: String = row.getAs[String](1)
    val prev: Double = buffer.getDouble(0)
    val current: Double = row.getAs[Double](0)
    updateType match {
      case "relative" => prev + current
      case "absolute" => current
      case _ => current
    }
  }
}
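As a possible alternative that avoids a custom aggregate, the same result can probably be reached with built-in window functions only. This is a different technique from the UDAF above, sketched under a few assumptions: the object name `RunningSumWithReset`, the helper `withRunningSum` and the intermediate `segment` column are hypothetical names, Spark 2.1 or later is assumed for `Window.unboundedPreceding`, and any update type other than "absolute" is treated as "relative" (unlike the UDAF, which resets on unknown types). The idea is to count the "absolute" rows seen so far to number the segments, then take a running sum inside each segment.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, sum, when}

object RunningSumWithReset {
  // Hypothetical helper: expects columns id, date, updateType, value.
  def withRunningSum(df: DataFrame): DataFrame = {
    val ordered = Window.partitionBy(col("id")).orderBy(col("date"))

    // Every "absolute" row opens a new segment: the running count of
    // "absolute" rows seen so far serves as the segment number.
    val segmented = df.withColumn(
      "segment",
      sum(when(col("updateType") === "absolute", 1).otherwise(0)).over(ordered))

    // Running sum restricted to the current segment, row by row.
    val perSegment = Window
      .partitionBy(col("id"), col("segment"))
      .orderBy(col("date"))
      .rowsBetween(Window.unboundedPreceding, Window.currentRow)

    segmented.select(
      col("id"),
      col("date"),
      sum(col("value")).over(perSegment).alias("sumValue"))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("running-sum-with-reset")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(
      (1, "2018-02-21", "relative", 3.00),
      (1, "2018-02-22", "relative", 4.00),
      (1, "2018-02-23", "absolute", 5.00),
      (1, "2018-02-24", "relative", 6.00),
      (1, "2018-02-26", "relative", 8.00)
    ).toDF("id", "date", "updateType", "value")

    withRunningSum(df).orderBy($"id", $"date").show()
    spark.stop()
  }
}
```

On the sample data the "absolute" row on 2018-02-23 starts segment 1, so the last three rows accumulate as 5, 11 and 19. Because the segment sum starts at the "absolute" row's own value, this only matches the UDAF when every segment's first row should contribute its value as the new total, which is the case in the question.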