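In the example below, every calculation is applied to the same set of original records. What I need instead is for each step to use the previously computed value to produce the next quantity.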
package playground

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{KeyValueGroupedDataset, SparkSession}

object basic2 extends App {
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  val spark = SparkSession
    .builder()
    .appName("Sample app")
    .master("local")
    .getOrCreate()

  import spark.implicits._

  final case class Owner(car: String, pcode: String, qtty: Double)
  final case class Invoice(car: String, pcode: String, qtty: Double)

  val data = Seq(
    Owner("A", "666", 80),
    Owner("B", "555", 20),
    Owner("A", "444", 50),
    Owner("A", "222", 20),
    Owner("C", "444", 20),
    Owner("C", "666", 80),
    Owner("C", "555", 120),
    Owner("A", "888", 100)
  )

  val fleet = Seq(Invoice("A", "666", 15), Invoice("A", "888", 12))

  val owners = spark.createDataset(data)
  val invoices = spark.createDataset(fleet)

  val gb: KeyValueGroupedDataset[Invoice, (Owner, Invoice)] = owners
    .joinWith(invoices, invoices("car") === owners("car"), "inner")
    .groupByKey(_._2)

  gb.flatMapGroups {
    case (fleet, group) =>
      val subOwner: Vector[Owner] = group.toVector.map(_._1)
      val calculatedRes = subOwner.filter(_.car == fleet.car)
      calculatedRes.map(c => c.copy(qtty = .3 * c.qtty + fleet.qtty))
  }
  .show()
}
/**
 * +---+-----+----+
 * |car|pcode|qtty|
 * +---+-----+----+
 * |  A|  666|39.0|
 * |  A|  444|30.0|
 * |  A|  222|21.0|
 * |  A|  888|45.0|
 * |  A|  666|36.0|
 * |  A|  444|27.0|
 * |  A|  222|18.0|
 * |  A|  888|42.0|
 * +---+-----+----+
 *
 * +---+-----+---------------+
 * |car|pcode|           qtty|
 * +---+-----+---------------+
 * |  A|  666|0.3 * 39.0 + 12|
 * |  A|  444|0.3 * 30.0 + 12|
 * |  A|  222|0.3 * 21.0 + 12|
 * |  A|  888|0.3 * 45.0 + 12|
 * +---+-----+---------------+
 */
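The second table above shows the expected output; the first is what the code in this question actually produces.
How can I produce the expected output iteratively?
Note that the order in which the invoices are applied does not matter: the results will differ depending on the order, but any of them is a valid answer.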
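To make the intended recurrence concrete, here is a minimal sketch on plain Scala collections, with no Spark involved. The object name `IterativeQttySketch` and the helpers `applyInvoices` and `trace` are hypothetical names introduced only for this illustration; the sketch assumes each step computes `0.3 * previous + invoiceQtty`, as in the expected table above.

```scala
object IterativeQttySketch {
  // Hypothetical helper (not part of the original code): applies each invoice
  // quantity in turn, feeding the previous result into the next step.
  def applyInvoices(ownerQtty: Double, invoiceQttys: Seq[Double]): Double =
    invoiceQttys.foldLeft(ownerQtty)((acc, q) => 0.3 * acc + q)

  // Same steps, but keeps the starting value and every intermediate result.
  def trace(ownerQtty: Double, invoiceQttys: Seq[Double]): Seq[Double] =
    invoiceQttys.scanLeft(ownerQtty)((acc, q) => 0.3 * acc + q)

  def main(args: Array[String]): Unit = {
    // Owner("A", "666", 80) with invoices 15 then 12:
    // 80 -> 0.3 * 80 + 15 = 39 -> 0.3 * 39 + 12
    println(trace(80.0, Seq(15.0, 12.0)))

    // Reversed invoice order gives a different final value:
    // 80 -> 0.3 * 80 + 12 = 36 -> 0.3 * 36 + 15
    println(trace(80.0, Seq(12.0, 15.0)))
  }
}
```

The two calls show why the final quantity depends on the order in which the invoices are applied, even though either order is acceptable here.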
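Answer 0 (score: 0)
Try the following code. It collects the invoice quantities for each owner row and folds over them, so each step uses the result of the previous one.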
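import org.apache.spark.sql.functions.{collect_list, first, sort_array, udf}

// Applies the invoices one after another: each step feeds the previous result back in.
// Assumes every group has at least one invoice, which the inner join guarantees.
val getQtty = udf((invoicesQtty: Seq[Double], ownersQtty: Double) => {
  invoicesQtty.tail.foldLeft(0.3 * ownersQtty + invoicesQtty.head)(
    (totalIQ, nextInvoiceQtty) => 0.3 * totalIQ + nextInvoiceQtty
  )
})

// Renders the same calculation as a readable expression
// (the string is empty when there is only one invoice).
val getQttyStr = udf((invoicesQtty: Seq[Double], ownersQtty: Double) => {
  val totalIQ = 0.3 * ownersQtty + invoicesQtty.head
  invoicesQtty.tail.foldLeft("")(
    (data, nextInvoiceQtty) =>
      s"0.3 * ${if (data.isEmpty) totalIQ else s"($data)"} + $nextInvoiceQtty"
  )
})

owners
  .join(invoices, invoices("car") === owners("car"), "inner")
  .groupBy(owners("car"), owners("pcode"))
  .agg(
    // Sort inside the aggregate: an orderBy before groupBy does not guarantee
    // the order in which collect_list sees the rows after the shuffle.
    sort_array(collect_list(invoices("qtty")), asc = false).as("invoices_qtty"),
    first(owners("qtty")).as("owners_qtty")
  )
  .withColumn("qtty", getQtty($"invoices_qtty", $"owners_qtty"))
  .withColumn("qtty_str", getQttyStr($"invoices_qtty", $"owners_qtty"))
  .show(false)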
Result
+---+-----+-------------+-----------+----+-----------------+
|car|pcode|invoices_qtty|owners_qtty|qtty|qtty_str         |
+---+-----+-------------+-----------+----+-----------------+
|A  |666  |[15.0, 12.0] |80.0       |23.7|0.3 * 39.0 + 12.0|
|A  |888  |[15.0, 12.0] |100.0      |25.5|0.3 * 45.0 + 12.0|
|A  |444  |[15.0, 12.0] |50.0       |21.0|0.3 * 30.0 + 12.0|
|A  |222  |[15.0, 12.0] |20.0       |18.3|0.3 * 21.0 + 12.0|
+---+-----+-------------+-----------+----+-----------------+