如何使用Spark计算限制下的累计和?

时间:2020-03-04 22:38:48

标签: sql scala apache-spark

经过几次尝试和研究,我坚持尝试使用Spark解决以下问题。

我有一个具有优先级和数量的元素的数据框。

+------+-------+--------+---+
|family|element|priority|qty|
+------+-------+--------+---+
|    f1| elmt 1|       1| 20|
|    f1| elmt 2|       2| 40|
|    f1| elmt 3|       3| 10|
|    f1| elmt 4|       4| 50|
|    f1| elmt 5|       5| 40|
|    f1| elmt 6|       6| 10|
|    f1| elmt 7|       7| 20|
|    f1| elmt 8|       8| 10|
+------+-------+--------+---+

我有固定的限制数量:

+------+--------+
|family|limitQty|
+------+--------+
|    f1|     100|
+------+--------+

我要将累积总和低于限制的元素标记为“确定”。这是预期的结果:

+------+-------+--------+---+---+
|family|element|priority|qty| ok|
+------+-------+--------+---+---+
|    f1| elmt 1|       1| 20|  1| -> 20 < 100   => ok
|    f1| elmt 2|       2| 40|  1| -> 20 + 40 < 100  => ok
|    f1| elmt 3|       3| 10|  1| -> 20 + 40 + 10 < 100   => ok
|    f1| elmt 4|       4| 50|  0| -> 20 + 40 + 10 + 50 > 100   => ko 
|    f1| elmt 5|       5| 40|  0| -> 20 + 40 + 10 + 40 > 100   => ko  
|    f1| elmt 6|       6| 10|  1| -> 20 + 40 + 10 + 10 < 100   => ok
|    f1| elmt 7|       7| 20|  1| -> 20 + 40 + 10 + 10 + 20 < 100   => ok
|    f1| elmt 8|       8| 10|  0| -> 20 + 40 + 10 + 10 + 20 + 10 > 100   => ko
+------+-------+--------+---+---+  

我尝试解决是否具有累计和:

    initDF
      .join(limitQtyDF, Seq("family"), "left_outer")
      .withColumn("cumulSum", sum($"qty").over(Window.partitionBy("family").orderBy("priority")))
      .withColumn("ok", when($"cumulSum" <= $"limitQty", 1).otherwise(0))
      .drop("cumulSum", "limitQty")

但这是不够的,因为没有考虑到元素之后的元素。 我找不到用Spark解决它的方法。你有个主意吗?

以下是对应的Scala代码:

    val sparkSession = SparkSession.builder()
      .master("local[*]")
      .getOrCreate()

    import sparkSession.implicits._

    val initDF = Seq(
      ("f1", "elmt 1", 1, 20),
      ("f1", "elmt 2", 2, 40),
      ("f1", "elmt 3", 3, 10),
      ("f1", "elmt 4", 4, 50),
      ("f1", "elmt 5", 5, 40),
      ("f1", "elmt 6", 6, 10),
      ("f1", "elmt 7", 7, 20),
      ("f1", "elmt 8", 8, 10)
    ).toDF("family", "element", "priority", "qty")

    val limitQtyDF = Seq(("f1", 100)).toDF("family", "limitQty")

    val expectedDF = Seq(
      ("f1", "elmt 1", 1, 20, 1),
      ("f1", "elmt 2", 2, 40, 1),
      ("f1", "elmt 3", 3, 10, 1),
      ("f1", "elmt 4", 4, 50, 0),
      ("f1", "elmt 5", 5, 40, 0),
      ("f1", "elmt 6", 6, 10, 1),
      ("f1", "elmt 7", 7, 20, 1),
      ("f1", "elmt 8", 8, 10, 0)
    ).toDF("family", "element", "priority", "qty", "ok").show()

谢谢您的帮助!

5 个答案:

答案 0 :(得分:1)

解决方案如下所示:

@app.route("/webhook", methods = ["GET", "POST"])

让我知道是否有帮助!

答案 1 :(得分:0)

另一种实现方法是逐行迭代基于RDD的方法。

var bufferRow: collection.mutable.Buffer[Row] = collection.mutable.Buffer.empty[Row]
var tempSum: Double = 0
val iterator = df.collect.iterator
while(iterator.hasNext){
  val record = iterator.next()
  val y = record.getAs[Integer]("qty")
  tempSum = tempSum + y
  print(record)
  if (tempSum <= 100.0 ) {
    bufferRow = bufferRow ++ Seq(transformRow(record,1))
  }
  else{
    bufferRow = bufferRow ++ Seq(transformRow(record,0))
    tempSum = tempSum - y
  }
}

定义transformRow函数,该函数用于向行添加列。

def transformRow(row: Row,flag : Int): Row =  Row.fromSeq(row.toSeq ++ Array[Integer](flag))

接下来要做的是在架构中添加额外的列。

val newSchema = StructType(df.schema.fields ++ Array(StructField("C_Sum", IntegerType, false))

接着创建一个新的数据框。

val outputdf = spark.createDataFrame(spark.sparkContext.parallelize(bufferRow.toSeq),newSchema)

输出数据框:

+------+-------+--------+---+-----+
|family|element|priority|qty|C_Sum|
+------+-------+--------+---+-----+
|    f1|  elmt1|       1| 20|    1|
|    f1|  elmt2|       2| 40|    1|
|    f1|  elmt3|       3| 10|    1|
|    f1|  elmt4|       4| 50|    0|
|    f1|  elmt5|       5| 40|    0|
|    f1|  elmt6|       6| 10|    1|
|    f1|  elmt7|       7| 20|    1|
|    f1|  elmt8|       8| 10|    0|
+------+-------+--------+---+-----+

答案 2 :(得分:0)

我是Spark的新手,因此此解决方案可能不是最佳解决方案。我假设100的值是此处程序的输入。在这种情况下:

case class Frame(family:String, element : String, priority : Int, qty :Int)

import scala.collection.JavaConverters._
val ans = df.as[Frame].toLocalIterator
  .asScala
  .foldLeft((Seq.empty[Int],0))((acc,a) => 
    if(acc._2 + a.qty <= 100) (acc._1 :+ a.priority, acc._2 + a.qty) else acc)._1

df.withColumn("OK" , when($"priority".isin(ans :_*), 1).otherwise(0)).show

导致:

+------+-------+--------+---+--------+
|family|element|priority|qty|OK      |
+------+-------+--------+---+--------+
|    f1| elmt 1|       1| 20|       1|
|    f1| elmt 2|       2| 40|       1|
|    f1| elmt 3|       3| 10|       1|
|    f1| elmt 4|       4| 50|       0|
|    f1| elmt 5|       5| 40|       0|
|    f1| elmt 6|       6| 10|       1|
|    f1| elmt 7|       7| 20|       1|
|    f1| elmt 8|       8| 10|       0|
+------+-------+--------+---+--------+

这个想法仅仅是获得一个Scala迭代器并从中提取参与的priority值,然后使用这些值来筛选出参与的行。鉴于此解决方案可以在一台计算机上收集内存中的所有数据,如果数据帧大小太大而无法容纳在内存中,则可能会遇到内存问题。

答案 3 :(得分:0)

每组的累计金额

from pyspark.sql.window import Window as window
from pyspark.sql.types import IntegerType,StringType,FloatType,StructType,StructField,DateType
schema = StructType() \
        .add(StructField("empno",IntegerType(),True)) \
        .add(StructField("ename",StringType(),True)) \
        .add(StructField("job",StringType(),True)) \
        .add(StructField("mgr",StringType(),True)) \
        .add(StructField("hiredate",DateType(),True)) \
        .add(StructField("sal",FloatType(),True)) \
        .add(StructField("comm",StringType(),True)) \
        .add(StructField("deptno",IntegerType(),True))

emp = spark.read.csv('data/emp.csv',schema)
dept_partition = window.partitionBy(emp.deptno).orderBy(emp.sal)
emp_win = emp.withColumn("dept_cum_sal", 
                         f.sum(emp.sal).over(dept_partition.rowsBetween(window.unboundedPreceding, window.currentRow)))
emp_win.show()

结果显示如下:

+-----+------+---------+----+----------+------+-------+------+------------ 
+
|empno| ename|      job| mgr|  hiredate|   sal|   comm|deptno|dept_cum_sal|
+-----+------+---------+----+----------+------+-------+------+------------ 
+
| 7369| SMITH|    CLERK|7902|1980-12-17| 800.0|   null|    20|       800.0|
| 7876| ADAMS|    CLERK|7788|1983-01-12|1100.0|   null|    20|      1900.0|
| 7566| JONES|  MANAGER|7839|1981-04-02|2975.0|   null|    20|      4875.0|
| 7788| SCOTT|  ANALYST|7566|1982-12-09|3000.0|   null|    20|      7875.0|
| 7902|  FORD|  ANALYST|7566|1981-12-03|3000.0|   null|    20|     10875.0|
| 7934|MILLER|    CLERK|7782|1982-01-23|1300.0|   null|    10|      1300.0|
| 7782| CLARK|  MANAGER|7839|1981-06-09|2450.0|   null|    10|      3750.0|
| 7839|  KING|PRESIDENT|null|1981-11-17|5000.0|   null|    10|      8750.0|
| 7900| JAMES|    CLERK|7698|1981-12-03| 950.0|   null|    30|       950.0|
| 7521|  WARD| SALESMAN|7698|1981-02-22|1250.0| 500.00|    30|      2200.0|
| 7654|MARTIN| SALESMAN|7698|1981-09-28|1250.0|1400.00|    30|      3450.0|
| 7844|TURNER| SALESMAN|7698|1981-09-08|1500.0|   0.00|    30|      4950.0|
| 7499| ALLEN| SALESMAN|7698|1981-02-20|1600.0| 300.00|    30|      6550.0|
| 7698| BLAKE|  MANAGER|7839|1981-05-01|2850.0|   null|    30|      9400.0|
+-----+------+---------+----+----------+------+-------+------+------------+

答案 4 :(得分:0)

PFA 答案

val initDF = Seq(("f1", "elmt 1", 1, 20),("f1", "elmt 2", 2, 40),("f1", "elmt 3", 3, 10),
      ("f1", "elmt 4", 4, 50),
      ("f1", "elmt 5", 5, 40),
      ("f1", "elmt 6", 6, 10),
      ("f1", "elmt 7", 7, 20),
      ("f1", "elmt 8", 8, 10)
    ).toDF("family", "element", "priority", "qty")

val limitQtyDF = Seq(("f1", 100)).toDF("family", "limitQty")

sc.broadcast(limitQtyDF)


val joinedInitDF=initDF.join(limitQtyDF,Seq("family"),"left")

case class dataResult(family:String,element:String,priority:Int, qty:Int, comutedValue:Int, limitQty:Int,controlOut:String) 
val familyIDs=initDF.select("family").distinct.collect.map(_(0).toString).toList

def checkingUDF(inputRows:List[Row])={
var controlVarQty=0
val outputArrayBuffer=collection.mutable.ArrayBuffer[dataResult]()
val setLimit=inputRows.head.getInt(4) 
for(inputRow <- inputRows)
{
val currQty=inputRow.getInt(3) 
//val outpurForRec=
controlVarQty + currQty match {
case value if value <= setLimit => 
controlVarQty+=currQty
outputArrayBuffer+=dataResult(inputRow.getString(0),inputRow.getString(1),inputRow.getInt(2),inputRow.getInt(3),value,setLimit,"ok")
case value => 
outputArrayBuffer+=dataResult(inputRow.getString(0),inputRow.getString(1),inputRow.getInt(2),inputRow.getInt(3),value,setLimit,"ko")
}
//outputArrayBuffer+=Row(inputRow.getString(0),inputRow.getString(1),inputRow.getInt(2),inputRow.getInt(3),controlVarQty+currQty,setLimit,outpurForRec)
}
outputArrayBuffer.toList
}

val tmpAB=collection.mutable.ArrayBuffer[List[dataResult]]()
for (familyID <- familyIDs) // val familyID="f1"
{
val currentFamily=joinedInitDF.filter(s"family = '${familyID}'").orderBy("element", "priority").collect.toList
tmpAB+=checkingUDF(currentFamily)
}

tmpAB.toSeq.flatMap(x => x).toDF.show(false)

这对我有用。

+------+-------+--------+---+------------+--------+----------+
|family|element|priority|qty|comutedValue|limitQty|controlOut|
+------+-------+--------+---+------------+--------+----------+
|f1    |elmt 1 |1       |20 |20          |100     |ok        |
|f1    |elmt 2 |2       |40 |60          |100     |ok        |
|f1    |elmt 3 |3       |10 |70          |100     |ok        |
|f1    |elmt 4 |4       |50 |120         |100     |ko        |
|f1    |elmt 5 |5       |40 |110         |100     |ko        |
|f1    |elmt 6 |6       |10 |80          |100     |ok        |
|f1    |elmt 7 |7       |20 |100         |100     |ok        |
|f1    |elmt 8 |8       |10 |110         |100     |ko        |
+------+-------+--------+---+------------+--------+----------+

请从输出中删除不必要的列