In spark-sql, I have a DataFrame with a column col containing an Int array of size 100 (for example).
I want to aggregate this column into a single value that is an Int array of size 100, containing the element-wise sum over the column. This can be done by calling:
dataframe.agg(functions.array((0 until 100).map(i => functions.sum(dataframe("col")(i))) : _*))
This generates code that explicitly does 100 aggregations and then presents the 100 results as a 100-item array. However, this seems very inefficient, and Catalyst cannot even generate code for it once my arrays grow beyond ~1000 items.
Is there a construct in spark-sql that does this more efficiently?
Ideally, the sum aggregation would automatically propagate over the array to do element-wise sums, but I haven't found anything related to this in the documentation.
What are the alternatives to my code?
Edit: my stack trace:
ERROR codegen.CodeGenerator: failed to compile: org.codehaus.janino.InternalCompilerException: Compiling "GeneratedClass": Code of method "processNext()V" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator" grows beyond 64 KB
org.codehaus.janino.InternalCompilerException: Compiling "GeneratedClass": Code of method "processNext()V" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator" grows beyond 64 KB
at org.codehaus.janino.UnitCompiler.compileUnit(UnitCompiler.java:361)
at org.codehaus.janino.SimpleCompiler.cook(SimpleCompiler.java:234)
at org.codehaus.janino.SimpleCompiler.compileToClassLoader(SimpleCompiler.java:446)
at org.codehaus.janino.ClassBodyEvaluator.compileToClass(ClassBodyEvaluator.java:313)
at org.codehaus.janino.ClassBodyEvaluator.cook(ClassBodyEvaluator.java:235)
at org.codehaus.janino.SimpleCompiler.cook(SimpleCompiler.java:204)
at org.codehaus.commons.compiler.Cookable.cook(Cookable.java:80)
at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.org$apache$spark$sql$catalyst$expressions$codegen$CodeGenerator$$doCompile(CodeGenerator.scala:1002)
at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$$anon$1.load(CodeGenerator.scala:1069)
at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$$anon$1.load(CodeGenerator.scala:1066)
at org.spark_project.guava.cache.LocalCache$LoadingValueReference.loadFuture(LocalCache.java:3599)
at org.spark_project.guava.cache.LocalCache$Segment.loadSync(LocalCache.java:2379)
at org.spark_project.guava.cache.LocalCache$Segment.lockedGetOrLoad(LocalCache.java:2342)
at org.spark_project.guava.cache.LocalCache$Segment.get(LocalCache.java:2257)
at org.spark_project.guava.cache.LocalCache.get(LocalCache.java:4000)
at org.spark_project.guava.cache.LocalCache.getOrLoad(LocalCache.java:4004)
at org.spark_project.guava.cache.LocalCache$LocalLoadingCache.get(LocalCache.java:4874)
at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.compile(CodeGenerator.scala:948)
at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:375)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:97)
at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:92)
at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doExecute(HashAggregateExec.scala:92)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:97)
at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:92)
at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doExecute(HashAggregateExec.scala:92)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
at org.apache.spark.sql.execution.exchange.ShuffleExchange.prepareShuffleDependency(ShuffleExchange.scala:88)
at org.apache.spark.sql.execution.exchange.ShuffleExchange$$anonfun$doExecute$1.apply(ShuffleExchange.scala:124)
at org.apache.spark.sql.execution.exchange.ShuffleExchange$$anonfun$doExecute$1.apply(ShuffleExchange.scala:115)
at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
at org.apache.spark.sql.execution.exchange.ShuffleExchange.doExecute(ShuffleExchange.scala:115)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply$mcV$sp(FileFormatWriter.scala:173)
at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:166)
at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:166)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:65)
at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:166)
at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:145)
at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:58)
at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:56)
at org.apache.spark.sql.execution.command.ExecutedCommandExec.doExecute(commands.scala:74)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
at org.apache.spark.sql.execution.datasources.DataSource.writeInFileFormat(DataSource.scala:435)
at org.apache.spark.sql.execution.datasources.DataSource.write(DataSource.scala:471)
at org.apache.spark.sql.execution.datasources.SaveIntoDataSourceCommand.run(SaveIntoDataSourceCommand.scala:48)
at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:58)
at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:56)
at org.apache.spark.sql.execution.command.ExecutedCommandExec.doExecute(commands.scala:74)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:609)
at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:233)
at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:217)
at org.apache.spark.sql.DataFrameWriter.csv(DataFrameWriter.scala:597)
at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer$.runAndSaveMetrics(RankingMetricsComputer.scala:286)
at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer$.main(RankingMetricsComputer.scala:366)
at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer.main(RankingMetricsComputer.scala)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at org.apache.spark.deploy.yarn.ApplicationMaster$$anon$2.run(ApplicationMaster.scala:635)
Answer 0 (score: 1):
The best approach is to explode the nested array into its own rows so that you can use a single groupBy. That way you can do it all in one aggregation instead of 100 (or more). The key to doing this is posexplode, which turns each entry of the array into a new Row together with its index in the array.
For example:
import org.apache.spark.sql.functions.{posexplode, collect_list, struct, sum}
import spark.implicits._  // assumes a SparkSession named `spark` is in scope (e.g. in spark-shell)

val data = Seq(
  (Seq(1, 2, 3, 4, 5)),
  (Seq(2, 3, 4, 5, 6)),
  (Seq(3, 4, 5, 6, 7))
)

val df = data.toDF  // a single array column named "value"

// posexplode yields one row per array element, with columns "pos" (index) and "col" (value)
val df2 = df.
  select(posexplode($"value")).
  groupBy($"pos").
  agg(sum($"col") as "sum")

// At this point you will have rows with the index and the sum
df2.orderBy($"pos".asc).show
which will output a DataFrame like this:
+---+---+
|pos|sum|
+---+---+
| 0| 6|
| 1| 9|
| 2| 12|
| 3| 15|
| 4| 18|
+---+---+
Or, if you want them back in a single row, you can do this:
df2.groupBy().agg(collect_list(struct($"pos", $"sum")) as "list").show
This won't order the values in the array column, but you could write a UDF to sort them by the pos field and then drop the pos field if you want to; a rough sketch of such a UDF follows.
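For example, a minimal sketch of that UDF (the helper name sortByPos is hypothetical, and it assumes the sum column is a bigint, as Spark produces when summing integers):

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf

// Hypothetical helper: sort the collected (pos, sum) structs by pos and keep only the sums
val sortByPos = udf { entries: Seq[Row] =>
  entries.sortBy(_.getAs[Int]("pos")).map(_.getAs[Long]("sum"))
}

df2.groupBy().
  agg(collect_list(struct($"pos", $"sum")) as "list").
  select(sortByPos($"list") as "sums").
  show(false)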
Updated per the comments:
If the above approach doesn't work for whatever other aggregations you are trying to do, then you will need to define your own UDAF. The basic idea here is that you tell Spark how to combine values for the same key within a partition to create intermediate values, and then how to combine those intermediate values across partitions to create the final value for each key. Once you have defined a UDAF class, you can use it in your agg call along with any other aggregations you want to perform.
Here is a quick example I knocked out. Note that it assumes the array length, and it should probably be made more error-proof, but it should get you most of the way there.
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
class ArrayCombine extends UserDefinedAggregateFunction {
  // The input this aggregation will receive (each row)
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("value", ArrayType(IntegerType)) :: Nil)

  // Your intermediate state as you are updating with data from each row
  override def bufferSchema: StructType =
    StructType(StructField("value", ArrayType(IntegerType)) :: Nil)

  // This is the output type of your aggregation function.
  override def dataType: DataType = ArrayType(IntegerType)

  override def deterministic: Boolean = true

  // This is the initial value for your buffer: an array of zeros (assumes arrays of length 100)
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Array.fill(100)(0)
  }

  // Given a new input row, update our state
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val sums = buffer.getSeq[Int](0)
    val newVals = input.getSeq[Int](0)
    buffer(0) = sums.zip(newVals).map { case (a, b) => a + b }
  }

  // After we have finished computing intermediate values for each partition, combine the partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val sums1 = buffer1.getSeq[Int](0)
    val sums2 = buffer2.getSeq[Int](0)
    buffer1(0) = sums1.zip(sums2).map { case (a, b) => a + b }
  }

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    buffer.getSeq[Int](0)
  }
}
Then call it like this:
val arrayUdaf = new ArrayCombine()
df.groupBy().agg(arrayUdaf($"value")).show
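As a purely illustrative sketch of combining the UDAF with other aggregations in the same agg call (the grouping column someKey is hypothetical and not part of the example data above):

import org.apache.spark.sql.functions.count

// Hypothetical: per-key element-wise array sums alongside an ordinary count
df.groupBy($"someKey").
  agg(arrayUdaf($"value") as "sums", count($"value") as "n").
  show(false)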