Sum values and restart the sum on a condition in a Spark window function

Date: 2018-01-10 17:50:19

Tags: scala apache-spark-sql spark-dataframe

I want to sum the values in a column as long as the update type is 'relative', and restart the sum whenever it is 'absolute'.

Here is how I define my DataFrame:

import spark.implicits._

val df = sc.parallelize(Seq(
  (1, "2018-02-21", "relative", 3.00),
  (1, "2018-02-22", "relative", 4.00),
  (1, "2018-02-23", "absolute", 5.00),
  (1, "2018-02-24", "relative", 6.00),
  (1, "2018-02-26", "relative", 8.00)
)).toDF("id", "date", "updateType", "value")

I defined a UDF to decide when to sum and when not to. I want to order by date, then either add the value to the running total or reset it to the absolute value as needed:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{lag, udf}

val computeValue = udf((previous: java.math.BigDecimal, value: java.math.BigDecimal, updateType: String) => {
  updateType match {
    case "absolute" => value
    case "relative" => previous.add(value)
    case _ => previous
  }
})
val w = Window
  .partitionBy($"id")
  .orderBy($"date")

val result = df.select(
  $"id",
  $"date",
  computeValue(
    lag($"value", 1, 0).over(w),
    $"value",
    $"updateType"
  ).alias("sumValue")
)

This actually returns:

+---+----------+---------+
| id|      date| sumValue|
+---+----------+---------+
|  1|2018-02-21|3.000    |
|  1|2018-02-22|7.000    |
|  1|2018-02-23|5.00     |
|  1|2018-02-24|11.00    |
|  1|2018-02-26|14.00    |
+---+----------+---------+

What I'm looking for is:

+---+----------+---------+
| id|      date| sumValue|
+---+----------+---------+
|  1|2018-02-21|3.000    |
|  1|2018-02-22|7.000    |
|  1|2018-02-23|5.00     |
|  1|2018-02-24|11.00    |
|  1|2018-02-26|19.00    |
+---+----------+---------+
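To make the target semantics concrete, here is the same running sum with reset sketched in plain Scala outside Spark (all names here are illustrative, not part of the Spark API). The `lag`-based attempt above presumably falls short because `lag($"value", 1, 0)` hands the UDF the previous *raw* value, not the previously *computed* sum, so the accumulation is lost after each row:

```scala
// Illustrative sketch of the desired "sum with reset" as a left fold.
case class Update(date: String, updateType: String, value: Double)

def runningSum(rows: Seq[Update]): Seq[Double] =
  rows.scanLeft(0.0) { (acc, row) =>
    row.updateType match {
      case "absolute" => row.value        // reset to the absolute value
      case "relative" => acc + row.value  // accumulate
      case _          => acc
    }
  }.tail  // drop the initial accumulator

val rows = Seq(
  Update("2018-02-21", "relative", 3.00),
  Update("2018-02-22", "relative", 4.00),
  Update("2018-02-23", "absolute", 5.00),
  Update("2018-02-24", "relative", 6.00),
  Update("2018-02-26", "relative", 8.00)
)

runningSum(rows)  // Seq(3.0, 7.0, 5.0, 11.0, 19.0)
```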

1 Answer:

Answer 0 (score: 1):

The answer is to use a UDAF (User Defined Aggregate Function) for this kind of operation:

// Init aggregation function to compute values
val computeValue = new ComputeValue
val w = Window
  .partitionBy($"id")
  .orderBy($"date")

val result = df.select(
  $"id",
  $"date",
  computeValue(
    $"value",
    $"updateType"
  ).over(w).alias("sumValue")
)

where the ComputeValue UDAF is:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class ComputeValue extends UserDefinedAggregateFunction {

  // Each row will be of type value: Double - update_type: String
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(
      StructField("value", DoubleType) ::
        StructField("update_type", StringType) :: Nil)

  // Another column where I will keep internal calculations
  override def bufferSchema: StructType = StructType(
    StructField("value", DoubleType) :: Nil
  )

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0.0

  // This is how to update your buffer schema given an input.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = computeValue(buffer, input)
  }

  // This is how to merge two objects with the bufferSchema type.
  // Note: buffer2 carries only the bufferSchema (a single Double), so it
  // cannot be fed to computeValue, which also expects an updateType column.
  // For the ordered window above Spark drives update row by row; here we
  // simply combine the partial sums.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  }

  // Get the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(0)
  }

  private def computeValue(buffer: MutableAggregationBuffer, row: Row): Double = {
    val updateType: String = row.getAs[String](1)
    val prev: Double = buffer.getDouble(0)
    val current: Double = row.getAs[Double](0)

    updateType match {
      case "relative" => prev + current
      case "absolute" => current
      case _ => current
    }
  }
}
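As an aside (a sketch not in the original answer): the same reset can often be expressed with built-in window functions alone — a running count of "absolute" rows acts as a segment id, and a running sum of `value` within each (id, segment) pair yields the desired column. The segmentation logic, sketched in plain Scala with illustrative names (`Rec`, `segmentedSums`):

```scala
// Illustrative sketch of the segment-id trick: a running count of
// "absolute" rows splits the sequence into segments, and a running
// sum restarted at each segment boundary reproduces the UDAF's output.
case class Rec(updateType: String, value: Double)

def segmentedSums(rows: Seq[Rec]): Seq[Double] = {
  // segment id = number of "absolute" rows seen so far (inclusive)
  val segIds = rows.scanLeft(0) { (seg, r) =>
    if (r.updateType == "absolute") seg + 1 else seg
  }.tail
  // running sum, restarted whenever the segment id changes
  rows.zip(segIds).scanLeft((0.0, -1)) { case ((acc, prevSeg), (r, seg)) =>
    val base = if (seg == prevSeg) acc else 0.0
    (base + r.value, seg)
  }.tail.map(_._1)
}
```

In Spark this would correspond roughly to `sum(when($"updateType" === "absolute", 1).otherwise(0)).over(w)` as the segment id, followed by a running `sum($"value")` over a window partitioned by (id, segment); treat that mapping as an untested assumption, since the original answer only shows the UDAF.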