将Spark Scala中当前行中的前一行值相加

时间:2019-02-11 05:48:41

标签: scala apache-spark apache-spark-sql

我正在尝试根据其他数据框中的值调整列值之一。这样做时,如果剩余量更多,则需要结转到下一行并计算最终金额。

在此操作期间,我无法保留上一行剩余的金额以用于下一行操作。我尝试使用滞后窗口功能并采用运行总计选项,但这些选项未按预期运行。

我正在与Scala合作。这是输入数据

val consumption = sc.parallelize(Seq((20180101, 600), (20180201, 900),(20180301, 400),(20180401, 600),(20180501, 1000),(20180601, 1900),(20180701, 500),(20180801, 100),(20180901, 500))).toDF("Month","Usage")
consumption.show()
+--------+-----+
|   Month|Usage|
+--------+-----+
|20180101|  600|
|20180201|  900|
|20180301|  400|
|20180401|  600|
|20180501| 1000|
|20180601| 1900|
|20180701|  500|
|20180801|  100|
|20180901|  500|
+--------+-----+
val promo = sc.parallelize(Seq((20180101, 1000),(20180201, 100),(20180401, 3000))).toDF("PromoEffectiveMonth","promoAmount")
promo.show()
+-------------------+-----------+
|PromoEffectiveMonth|promoAmount|
+-------------------+-----------+
|           20180101|       1000|
|           20180201|        100|
|           20180401|       3000|
+-------------------+-----------+

预期结果:

val finaldf = sc.parallelize(Seq((20180101,600,400,600),(20180201,900,0,400),(20180301,400,0,0),(20180401,600,2400,600),(20180501,1000,1400,1000),(20180601,1900,0,500),(20180701,500,0,0),(20180801,100,0,0),(20180901,500,0,0))).toDF("Month","Usage","LeftOverPromoAmt","AdjustedUsage")
finaldf.show()
+--------+-----+----------------+-------------+
|   Month|Usage|LeftOverPromoAmt|AdjustedUsage|
+--------+-----+----------------+-------------+
|20180101|  600|             400|          600|
|20180201|  900|               0|          400|
|20180301|  400|               0|            0|
|20180401|  600|            2400|          600|
|20180501| 1000|            1400|         1000|
|20180601| 1900|               0|          500|
|20180701|  500|               0|            0|
|20180801|  100|               0|            0|
|20180901|  500|               0|            0|
+--------+-----+----------------+-------------+

我要应用的逻辑基于“月”和“促销有效联接”,需要在消费使用列上应用促销金额,直到促销金额变为零为止。

例如:在1月18日,促销金额为1000,从使用量中扣除(600)后,剩余的促销金额为400,调整后的使用量为600。超过400的剩余金额将被考虑用于下个月的促销活动2月份的amt,那么最终的促销金额为500。与使用情况相比,此处的使用情况更多。

因此促销剩余额为零,调整使用量为400(900-500)。

1 个答案:

答案 0 :(得分:4)

首先,您需要执行left_outer连接,以便对每一行都有相应的升级。通过分别来自数据集MonthPromoEffectiveMonth的字段Consumptionpromo来执行联接操作。还要注意,我已经创建了一个新列Timestamp。它是使用Spark SQL unix_timestamp函数创建的。它将用于按日期对数据集进行排序。

val ds = consumption
    .join(promo, consumption.col("Month") === promo.col("PromoEffectiveMonth"), "left_outer")
    .select("UserID", "Month", "Usage", "promoAmount")
    .withColumn("Timestamp", unix_timestamp($"Month".cast("string"), "yyyyMMdd").cast(TimestampType))

这是这些操作的结果。

+--------+-----+-----------+-------------------+
|   Month|Usage|promoAmount|          Timestamp|
+--------+-----+-----------+-------------------+
|20180301|  400|       null|2018-03-01 00:00:00|
|20180701|  500|       null|2018-07-01 00:00:00|
|20180901|  500|       null|2018-09-01 00:00:00|
|20180101|  600|       1000|2018-01-01 00:00:00|
|20180801|  100|       null|2018-08-01 00:00:00|
|20180501| 1000|       null|2018-05-01 00:00:00|
|20180201|  900|        100|2018-02-01 00:00:00|
|20180601| 1900|       null|2018-06-01 00:00:00|
|20180401|  600|       3000|2018-04-01 00:00:00|
+--------+-----+-----------+-------------------+

接下来,您必须创建一个Window。窗口函数用于通过使用某些条件对一组记录进行计算(有关此here的更多信息)。在我们的例子中,标准是按Timestamp对每个组进行排序。

 val window = Window.orderBy("Timestamp")

好的,现在是最困难的部分。您需要创建一个User Defined Aggregate Function。在此功能中,将根据自定义操作处理每个组,并使您可以考虑上一行的值来处理每一行。

  class CalculatePromos extends UserDefinedAggregateFunction {
    // Input schema for this UserDefinedAggregateFunction
    override def inputSchema: StructType =
      StructType(
        StructField("Usage", LongType) ::
        StructField("promoAmount", LongType) :: Nil)

    // Schema for the parameters that will be used internally to buffer temporary values
    override def bufferSchema: StructType = StructType(
        StructField("AdjustedUsage", LongType) ::
        StructField("LeftOverPromoAmt", LongType) :: Nil
    )

    // The data type returned by this UserDefinedAggregateFunction.
    // In this case, it will return an StructType with two fields: AdjustedUsage and LeftOverPromoAmt
    override def dataType: DataType = StructType(Seq(StructField("AdjustedUsage", LongType), StructField("LeftOverPromoAmt", LongType)))

    // Whether this UDAF is deterministic or not. In this case, it is
    override def deterministic: Boolean = true

    // Initial values for the temporary values declared above
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L
      buffer(1) = 0L
    }

    // In this function, the values associated to the buffer schema are updated
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

      val promoAmount = if(input.isNullAt(1)) 0L else input.getLong(1)
      val leftOverAmount = buffer.getLong(1)
      val usage = input.getLong(0)
      val currentPromo = leftOverAmount + promoAmount

      if(usage < currentPromo) {
        buffer(0) = usage
        buffer(1) = currentPromo - usage
      } else {
        if(currentPromo == 0)
          buffer(0) = 0L
        else
          buffer(0) = usage - currentPromo
        buffer(1) = 0L
      }
    }

    // Function used to merge two objects. In this case, it is not necessary to define this method since
    // the whole logic has been implemented in update
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {}

    // It is what you will return. In this case, a tuple of the buffered values which rerpesent AdjustedUsage and LeftOverPromoAmt
    override def evaluate(buffer: Row): Any = {
      (buffer.getLong(0), buffer.getLong(1))
    }

  }

基本上,它创建一个可在Spark SQL中使用的函数,该函数接收两列(如方法Usage中指定的promoAmountinputSchema)并返回一个新列具有两个子列(AdjustedUsageLeftOverPromAmt,如方法dataType中所定义)。使用方法bufferSchema,您可以创建用于支持操作的临时值。在这种情况下,我已经定义了AdjustedUsageLeftOverPromoAmt

您正在使用的逻辑在方法update中实现。基本上,它将采用先前计算的值并进行更新。参数buffer包含在bufferSchema中定义的临时值,并且input保留该时刻正在处理的行的值。最后,evaluate返回一个元组对象,其中包含每一行的操作结果,在本例中为bufferSchema中定义并在方法update中更新的临时值。

下一步是通过实例化类CalculatePromos来创建变量。

val calculatePromos = new CalculatePromos

最后,您必须使用数据集的方法calculatePromos应用用户定义的聚合函数withColumn。请注意,您必须将输入列(UsagepromoAmount传递给它,然后使用方法来应用窗口。

ds
  .withColumn("output", calculatePromos($"Usage", $"promoAmount").over(window))
  .select($"Month", $"Usage", $"output.LeftOverPromoAmt".as("LeftOverPromoAmt"), $"output.AdjustedUsage".as("AdjustedUsage"))

这是结果:

+--------+-----+----------------+-------------+
|   Month|Usage|LeftOverPromoAmt|AdjustedUsage|
+--------+-----+----------------+-------------+
|20180101|  600|             400|          600|
|20180201|  900|               0|          400|
|20180301|  400|               0|            0|
|20180401|  600|            2400|          600|
|20180501| 1000|            1400|         1000|
|20180601| 1900|               0|          500|
|20180701|  500|               0|            0|
|20180801|  100|               0|            0|
|20180901|  500|               0|            0|
+--------+-----+----------------+-------------+

希望有帮助。