I am trying to aggregate a dataframe on multiple columns. I know that everything I need for the aggregation is within the partition - that is, there is no need for a shuffle because all of the data for the aggregation is local to the partition.
For example, if I have something like

val sales = sc.parallelize(List(
  ("West", "Apple", 2.0, 10),
  ("West", "Apple", 3.0, 15),
  ("West", "Orange", 5.0, 15),
  ("South", "Orange", 3.0, 9),
  ("South", "Orange", 6.0, 18),
  ("East", "Milk", 5.0, 5))).repartition(2)
val tdf = sales.map { case (store, prod, amt, units) => ((store, prod), (amt, amt, amt, units)) }.
  reduceByKey((x, y) => (x._1 + y._1, math.min(x._2, y._2), math.max(x._3, y._3), x._4 + y._4))
println(tdf.toDebugString)
I get a result like
(2) ShuffledRDD[12] at reduceByKey at Test.scala:59 []
+-(2) MapPartitionsRDD[11] at map at Test.scala:58 []
| MapPartitionsRDD[10] at repartition at Test.scala:57 []
| CoalescedRDD[9] at repartition at Test.scala:57 []
| ShuffledRDD[8] at repartition at Test.scala:57 []
+-(1) MapPartitionsRDD[7] at repartition at Test.scala:57 []
| ParallelCollectionRDD[6] at parallelize at Test.scala:51 []
You can see the MapPartitionsRDD, which is fine. But then there is the ShuffledRDD, which I want to prevent because I want per-partition summaries, grouped by the column values within the partition.
zero323's suggestion is tantalizingly close, but I need the "group by columns" functionality.
Referring to my example above, I am looking for the result that would be produced by

select store, prod, sum(amt), avg(units) from sales group by partition_id, store, prod

(I don't really need the partition id - it is just there to illustrate that I want results per partition.)
I have looked at lots of examples, but every debug string I produce has a shuffle in it. I really want to get rid of the shuffle. I suppose what I am essentially looking for is a groupByKeysWithinPartitions function.
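For reference, a DataFrame equivalent of that SQL can be sketched with spark_partition_id() (assuming a hypothetical salesDF with columns store, prod, amt, units); it makes the desired per-partition output concrete, although its groupBy still introduces a shuffle, which is exactly what I want to avoid:

import org.apache.spark.sql.functions.{avg, spark_partition_id, sum}

// salesDF is an assumed DataFrame with columns store, prod, amt, units
val perPartitionSummary = salesDF
  .withColumn("pid", spark_partition_id())   // tag each row with its partition id
  .groupBy("pid", "store", "prod")           // still shuffles; shown only to illustrate the desired result
  .agg(sum("amt").as("sum_amt"), avg("units").as("avg_units"))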
Answer 0: (Score: 2)
The only way to achieve this is by using mapPartitions and having custom code for grouping and computing your values while iterating the partition. As you mention, the data is already sorted by the grouping keys (store, prod), so we can compute your aggregations efficiently in a pipelined fashion:
(1) Define the helper classes:
:paste
case class MyRec(store: String, prod: String, amt: Double, units: Int)

case class MyResult(store: String, prod: String, total_amt: Double, min_amt: Double, max_amt: Double, total_units: Int)

object MyResult {
  def apply(rec: MyRec): MyResult = new MyResult(rec.store, rec.prod, rec.amt, rec.amt, rec.amt, rec.units)

  def aggregate(result: MyResult, rec: MyRec) = {
    new MyResult(result.store,
      result.prod,
      result.total_amt + rec.amt,
      math.min(result.min_amt, rec.amt),
      math.max(result.max_amt, rec.amt),
      result.total_units + rec.units
    )
  }
}
(2) Define the pipelined aggregator:
:paste
def pipelinedAggregator(iter: Iterator[MyRec]): Iterator[Seq[MyResult]] = {
  var prev: MyResult = null
  var res: Seq[MyResult] = Nil
  for (crt <- iter) yield {
    // reset the per-record output so a finished group is emitted only once
    res = Nil
    if (prev == null) {
      prev = MyResult(crt)
    }
    else if (prev.prod != crt.prod || prev.store != crt.store) {
      // group boundary: emit the finished group and start a new running aggregate
      res = Seq(prev)
      prev = MyResult(crt)
    }
    else {
      prev = MyResult.aggregate(prev, crt)
    }
    if (!iter.hasNext) {
      // last record of the partition: also emit the group currently in progress
      res = res ++ Seq(prev)
    }
    res
  }
}
(3) Run the aggregation:
:paste
val sales = sc.parallelize(
  List(MyRec("West", "Apple", 2.0, 10),
    MyRec("West", "Apple", 3.0, 15),
    MyRec("West", "Orange", 5.0, 15),
    MyRec("South", "Orange", 3.0, 9),
    MyRec("South", "Orange", 6.0, 18),
    MyRec("East", "Milk", 5.0, 5),
    MyRec("West", "Apple", 7.0, 11)), 2).toDS

sales.mapPartitions(iter => Iterator(iter.toList)).show(false)

val result = sales
  .mapPartitions(recIter => pipelinedAggregator(recIter))
  .flatMap(identity)

result.show
result.explain
Output:
+-------------------------------------------------------------------------------------+
|value |
+-------------------------------------------------------------------------------------+
|[[West,Apple,2.0,10], [West,Apple,3.0,15], [West,Orange,5.0,15]] |
|[[South,Orange,3.0,9], [South,Orange,6.0,18], [East,Milk,5.0,5], [West,Apple,7.0,11]]|
+-------------------------------------------------------------------------------------+
+-----+------+---------+-------+-------+-----------+
|store| prod|total_amt|min_amt|max_amt|total_units|
+-----+------+---------+-------+-------+-----------+
| West| Apple| 5.0| 2.0| 3.0| 25|
| West|Orange| 5.0| 5.0| 5.0| 15|
|South|Orange| 9.0| 3.0| 6.0| 27|
| East| Milk| 5.0| 5.0| 5.0| 5|
| West| Apple| 7.0| 7.0| 7.0| 11|
+-----+------+---------+-------+-------+-----------+
== Physical Plan ==
*SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).store, true) AS store#31, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).prod, true) AS prod#32, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).total_amt AS total_amt#33, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).min_amt AS min_amt#34, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).max_amt AS max_amt#35, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).total_units AS total_units#36]
+- MapPartitions <function1>, obj#30: $line14.$read$$iw$$iw$MyResult
+- MapPartitions <function1>, obj#20: scala.collection.Seq
+- Scan ExternalRDDScan[obj#4]
sales: org.apache.spark.sql.Dataset[MyRec] = [store: string, prod: string ... 2 more fields]
result: org.apache.spark.sql.Dataset[MyResult] = [store: string, prod: string ... 4 more fields]
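Note that the pipelined aggregator above relies on the records of each (store, prod) group being contiguous within the partition. If that ordering cannot be guaranteed, a minimal sketch of the same shuffle-free idea using a mutable map inside mapPartitions (reusing the MyRec and MyResult helpers above; this variant is an illustration, not part of the original answer) could look like:

import scala.collection.mutable

// Aggregate within each partition, no shuffle, regardless of record order inside the partition.
def aggregateWithinPartition(iter: Iterator[MyRec]): Iterator[MyResult] = {
  val acc = mutable.Map.empty[(String, String), MyResult]
  iter.foreach { rec =>
    val key = (rec.store, rec.prod)
    acc(key) = acc.get(key) match {
      case Some(agg) => MyResult.aggregate(agg, rec)
      case None      => MyResult(rec)
    }
  }
  acc.valuesIterator
}

val resultUnordered = sales.mapPartitions(aggregateWithinPartition)
resultUnordered.show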
Answer 1: (Score: 0)
If this is the output you are looking for
+-----+------+--------+----------+
|store|prod  |max(amt)|avg(units)|
+-----+------+--------+----------+
|South|Orange|6.0     |13.5      |
|West |Orange|5.0     |15.0      |
|East |Milk  |5.0     |5.0       |
|West |Apple |3.0     |12.5      |
+-----+------+--------+----------+
Spark DataFrames have all the functionality you are asking for, with a generic, concise shorthand syntax
import org.apache.spark.sql._
import org.apache.spark.sql.functions._

object TestJob2 {

  def main(args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      ("West", "Apple", 2.0, 10),
      ("West", "Apple", 3.0, 15),
      ("West", "Orange", 5.0, 15),
      ("South", "Orange", 3.0, 9),
      ("South", "Orange", 6.0, 18),
      ("East", "Milk", 5.0, 5)
    ).toDF("store", "prod", "amt", "units")

    rawDf.show(false)
    rawDf.printSchema

    val aggDf = rawDf
      .groupBy("store", "prod")
      .agg(
        max(col("amt")),
        avg(col("units"))
        // in case you need to retain more info
        // , collect_list(struct("*")).as("horizontal")
      )

    aggDf.printSchema
    aggDf.show(false)
  }
}
Uncomment the collect_list line to roll everything up
+-----+------+--------+----------+---------------------------------------------------+
|store|prod  |max(amt)|avg(units)|horizontal                                         |
+-----+------+--------+----------+---------------------------------------------------+
|South|Orange|6.0     |13.5      |[[South, Orange, 3.0, 9], [South, Orange, 6.0, 18]]|
|West |Orange|5.0     |15.0      |[[West, Orange, 5.0, 15]]                          |
|East |Milk  |5.0     |5.0       |[[East, Milk, 5.0, 5]]                             |
|West |Apple |3.0     |12.5      |[[West, Apple, 2.0, 10], [West, Apple, 3.0, 15]]   |
+-----+------+--------+----------+---------------------------------------------------+
Answer 2: (Score: 0)
The max and avg aggregations you specified operate over multiple rows.
If you want to keep all of the original rows, use a Window function, which will partition the data.
If you want to reduce the rows within each partition, you have to specify a reduction logic or a filter (a filter sketch follows the output below).
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object TestJob7 {

  def main(args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext
    sc.setLogLevel("ERROR")

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      ("West", "Apple", 2.0, 10),
      ("West", "Apple", 3.0, 15),
      ("West", "Orange", 5.0, 15),
      ("South", "Orange", 3.0, 9),
      ("South", "Orange", 6.0, 18),
      ("East", "Milk", 5.0, 5)
    ).toDF("store", "prod", "amt", "units")

    rawDf.show(false)
    rawDf.printSchema

    val storeProdWindow = Window
      .partitionBy("store", "prod")

    val aggDf = rawDf
      .withColumn("max(amt)", max("amt").over(storeProdWindow))
      .withColumn("avg(units)", avg("units").over(storeProdWindow))

    aggDf.printSchema
    aggDf.show(false)
  }
}
Here is the result. Note that it is already grouped (the window shuffles the data into partitions):
+-----+------+---+-----+--------+----------+
|store|prod  |amt|units|max(amt)|avg(units)|
+-----+------+---+-----+--------+----------+
|South|Orange|3.0|9    |6.0     |13.5      |
|South|Orange|6.0|18   |6.0     |13.5      |
|West |Orange|5.0|15   |5.0     |15.0      |
|East |Milk  |5.0|5    |5.0     |5.0       |
|West |Apple |2.0|10   |3.0     |12.5      |
|West |Apple |3.0|15   |3.0     |12.5      |
+-----+------+---+-----+--------+----------+
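To then reduce each group back to a single row, as mentioned above, one option (an illustrative sketch, not part of the original answer) is to filter on the windowed maximum:

// Keep only the row(s) whose amt equals the per-group maximum computed by the window.
// Ties would keep more than one row; dropDuplicates("store", "prod") could break them.
val reducedDf = aggDf.where(aggDf("amt") === aggDf("max(amt)"))
reducedDf.show(false)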
Answer 3: (Score: 0)
Aggregate functions reduce the row values of the specified columns within a group. Using DataFrame functionality you can perform multiple different aggregations in a single pass, producing new columns whose values are taken from the input rows. If you want to retain other row values, you need to implement a reduction logic that specifies the row each value comes from, e.g. keep all the values of the first row with the maximum value of age. For this you can use a UDAF (user defined aggregate function) to reduce the rows within a group. In the example I also aggregate the max amt and the average units using standard aggregate functions in the same aggregation.
import org.apache.spark.sql._
import org.apache.spark.sql.functions._

object ReduceAggJob {

  def main(args: Array[String]): Unit = {

    val appName = this.getClass.getName.replace("$", "")
    println(s"appName: $appName")

    val sparkSession = SparkSession
      .builder()
      .appName(appName)
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext
    sc.setLogLevel("ERROR")

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      ("West", "Apple", 2.0, 10),
      ("West", "Apple", 3.0, 15),
      ("West", "Orange", 5.0, 15),
      ("West", "Orange", 17.0, 15),
      ("South", "Orange", 3.0, 9),
      ("South", "Orange", 6.0, 18),
      ("East", "Milk", 5.0, 5)
    ).toDF("store", "prod", "amt", "units")

    rawDf.printSchema
    rawDf.show(false)

    // Create an instance of the UDAF KeepRowWithMaxAmt.
    val maxAmtUdaf = new KeepRowWithMaxAmt

    // Keep the row with the max amt
    val aggDf = rawDf
      .groupBy("store", "prod")
      .agg(
        max("amt"),
        avg("units"),
        maxAmtUdaf(
          col("store"),
          col("prod"),
          col("amt"),
          col("units")).as("KeepRowWithMaxAmt")
      )

    aggDf.printSchema
    aggDf.show(false)
  }
}
The UDAF:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class KeepRowWithMaxAmt extends UserDefinedAggregateFunction {

  // This is the input fields for your aggregate function.
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(
      StructField("store", StringType) ::
      StructField("prod", StringType) ::
      StructField("amt", DoubleType) ::
      StructField("units", IntegerType) :: Nil
    )

  // This is the internal fields you keep for computing your aggregate.
  override def bufferSchema: StructType = StructType(
    StructField("store", StringType) ::
    StructField("prod", StringType) ::
    StructField("amt", DoubleType) ::
    StructField("units", IntegerType) :: Nil
  )

  // This is the output type of your aggregation function.
  override def dataType: DataType =
    StructType(Array(
      StructField("store", StringType),
      StructField("prod", StringType),
      StructField("amt", DoubleType),
      StructField("units", IntegerType)
    ))

  override def deterministic: Boolean = true

  // This is the initial value for your buffer schema (assumes amt values are non-negative).
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ""
    buffer(1) = ""
    buffer(2) = 0.0
    buffer(3) = 0
  }

  // This is how to update your buffer schema given an input.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val amt = buffer.getAs[Double](2)
    val candidateAmt = input.getAs[Double](2)

    amt match {
      case a if a < candidateAmt =>
        buffer(0) = input.getAs[String](0)
        buffer(1) = input.getAs[String](1)
        buffer(2) = input.getAs[Double](2)
        buffer(3) = input.getAs[Int](3)
      case _ =>
    }
  }

  // This is how to merge two objects with the bufferSchema type.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // Keep whichever partial result currently holds the larger amt.
    if (buffer2.getAs[Double](2) > buffer1.getAs[Double](2)) {
      buffer1(0) = buffer2.getAs[String](0)
      buffer1(1) = buffer2.getAs[String](1)
      buffer1(2) = buffer2.getAs[Double](2)
      buffer1(3) = buffer2.getAs[Int](3)
    }
  }

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    buffer
  }
}
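As a side note (not part of the original answer), UserDefinedAggregateFunction is deprecated since Spark 3.0 in favour of the typed Aggregator API, which can also be registered for DataFrames via functions.udaf. A minimal sketch of an equivalent aggregator, assuming Spark 3.x and a hypothetical case class AmtRow mirroring the four columns:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

case class AmtRow(store: String, prod: String, amt: Double, units: Int)

// Typed counterpart of KeepRowWithMaxAmt: keep the whole row with the largest amt per group.
object KeepRowWithMaxAmtAgg extends Aggregator[AmtRow, AmtRow, AmtRow] {
  def zero: AmtRow = AmtRow("", "", Double.MinValue, 0)
  def reduce(buf: AmtRow, row: AmtRow): AmtRow = if (row.amt > buf.amt) row else buf
  def merge(b1: AmtRow, b2: AmtRow): AmtRow = if (b2.amt > b1.amt) b2 else b1
  def finish(reduction: AmtRow): AmtRow = reduction
  def bufferEncoder: Encoder[AmtRow] = Encoders.product[AmtRow]
  def outputEncoder: Encoder[AmtRow] = Encoders.product[AmtRow]
}

// Usage with the typed Dataset API, assuming rawDs: Dataset[AmtRow] (e.g. rawDf.as[AmtRow]):
// rawDs.groupByKey(r => (r.store, r.prod))
//      .agg(KeepRowWithMaxAmtAgg.toColumn.name("KeepRowWithMaxAmt"))
//      .show(false)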