How to perform member-wise operations on arrays in Spark SQL?

Date: 2018-10-16 19:39:13

Tags: apache-spark apache-spark-sql

In spark-sql I have a DataFrame with a column col that contains an array of Ints of size 100 (for example).

I want to aggregate this column into a single value: an Int array of size 100 holding the element-wise sums over the column. This can be done by calling:

dataframe.agg(functions.array((0 until 100).map(i => functions.sum(functions.col("col")(i))) : _*))

This generates code that explicitly performs 100 separate aggregations and then assembles the 100 results into a 100-element array. However, this seems very inefficient: once my arrays grow beyond ~1000 items, Catalyst cannot even generate code for it. Is there a construct in spark-sql that does this more efficiently? Ideally the sum aggregation would propagate over the array automatically and sum member-wise, but I found nothing about this in the documentation. What are the alternatives to my code?

Edit: my stack trace:

   ERROR codegen.CodeGenerator: failed to compile: org.codehaus.janino.InternalCompilerException: Compiling "GeneratedClass": Code of method "processNext()V" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator" grows beyond 64 KB
org.codehaus.janino.InternalCompilerException: Compiling "GeneratedClass": Code of method "processNext()V" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator" grows beyond 64 KB
    at org.codehaus.janino.UnitCompiler.compileUnit(UnitCompiler.java:361)
    at org.codehaus.janino.SimpleCompiler.cook(SimpleCompiler.java:234)
    at org.codehaus.janino.SimpleCompiler.compileToClassLoader(SimpleCompiler.java:446)
    at org.codehaus.janino.ClassBodyEvaluator.compileToClass(ClassBodyEvaluator.java:313)
    at org.codehaus.janino.ClassBodyEvaluator.cook(ClassBodyEvaluator.java:235)
    at org.codehaus.janino.SimpleCompiler.cook(SimpleCompiler.java:204)
    at org.codehaus.commons.compiler.Cookable.cook(Cookable.java:80)
    at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.org$apache$spark$sql$catalyst$expressions$codegen$CodeGenerator$$doCompile(CodeGenerator.scala:1002)
    at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$$anon$1.load(CodeGenerator.scala:1069)
    at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$$anon$1.load(CodeGenerator.scala:1066)
    at org.spark_project.guava.cache.LocalCache$LoadingValueReference.loadFuture(LocalCache.java:3599)
    at org.spark_project.guava.cache.LocalCache$Segment.loadSync(LocalCache.java:2379)
    at org.spark_project.guava.cache.LocalCache$Segment.lockedGetOrLoad(LocalCache.java:2342)
    at org.spark_project.guava.cache.LocalCache$Segment.get(LocalCache.java:2257)
    at org.spark_project.guava.cache.LocalCache.get(LocalCache.java:4000)
    at org.spark_project.guava.cache.LocalCache.getOrLoad(LocalCache.java:4004)
    at org.spark_project.guava.cache.LocalCache$LocalLoadingCache.get(LocalCache.java:4874)
    at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.compile(CodeGenerator.scala:948)
    at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:375)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:97)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:92)
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doExecute(HashAggregateExec.scala:92)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:97)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:92)
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doExecute(HashAggregateExec.scala:92)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.exchange.ShuffleExchange.prepareShuffleDependency(ShuffleExchange.scala:88)
    at org.apache.spark.sql.execution.exchange.ShuffleExchange$$anonfun$doExecute$1.apply(ShuffleExchange.scala:124)
    at org.apache.spark.sql.execution.exchange.ShuffleExchange$$anonfun$doExecute$1.apply(ShuffleExchange.scala:115)
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
    at org.apache.spark.sql.execution.exchange.ShuffleExchange.doExecute(ShuffleExchange.scala:115)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply$mcV$sp(FileFormatWriter.scala:173)
    at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:166)
    at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:166)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:65)
    at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:166)
    at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:145)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:58)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:56)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.doExecute(commands.scala:74)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.datasources.DataSource.writeInFileFormat(DataSource.scala:435)
    at org.apache.spark.sql.execution.datasources.DataSource.write(DataSource.scala:471)
    at org.apache.spark.sql.execution.datasources.SaveIntoDataSourceCommand.run(SaveIntoDataSourceCommand.scala:48)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:58)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:56)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.doExecute(commands.scala:74)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
    at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:609)
    at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:233)
    at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:217)
    at org.apache.spark.sql.DataFrameWriter.csv(DataFrameWriter.scala:597)
    at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer$.runAndSaveMetrics(RankingMetricsComputer.scala:286)
    at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer$.main(RankingMetricsComputer.scala:366)
    at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer.main(RankingMetricsComputer.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at org.apache.spark.deploy.yarn.ApplicationMaster$$anon$2.run(ApplicationMaster.scala:635)

1 Answer:

Answer 0 (score: 1)

The best way to do this is to explode the nested arrays into their own rows so that a single groupBy can do all the work in one aggregation instead of 100 (or more). The key is posexplode, which turns each entry of the array into a new Row together with its index in the array.

For example:

import org.apache.spark.sql.functions.{posexplode, collect_list, sum, struct}
import spark.implicits._  // assumes an existing SparkSession named spark, for toDF and the $ syntax

val data = Seq(
    (Seq(1, 2, 3, 4, 5)),
    (Seq(2, 3, 4, 5, 6)),
    (Seq(3, 4, 5, 6, 7))
)

val df = data.toDF

val df2 = df.
    select(posexplode($"value")).
    groupBy($"pos").
    agg(sum($"col") as "sum")

// At this point you will have rows with the index and the sum
df2.orderBy($"pos".asc).show

Which will output a DataFrame like this:

+---+---+
|pos|sum|
+---+---+
|  0|  6|
|  1|  9|
|  2| 12|
|  3| 15|
|  4| 18|
+---+---+

Or, if you want them back in a single row, you can do:

df2.groupBy().agg(collect_list(struct($"pos", $"sum")) as "list").show

The values in the resulting array column will not be sorted, but you can write a UDF that sorts them by the pos field and then drops pos if you want.
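For instance, a minimal sketch of such a UDF (the sortByPos name is made up here; it also assumes sum($"col") produced a LongType column, which is what sum returns for integer input):

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf

// Sort the collected (pos, sum) structs by pos and keep only the sums.
val sortByPos = udf { entries: Seq[Row] =>
  entries
    .map(r => (r.getInt(0), r.getLong(1)))   // (pos, sum); sum over Ints yields Longs
    .sortBy { case (pos, _) => pos }
    .map { case (_, total) => total }
}

df2.groupBy()
  .agg(collect_list(struct($"pos", $"sum")) as "list")
  .select(sortByPos($"list") as "sums")
  .show(false)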

Updated per comments:

If the above does not work for whatever other aggregations you are trying to do, then you will need to define your own UDAF. The basic idea is that you tell Spark how to combine values for the same key within a partition to produce intermediate values, and then how to merge those intermediate values across partitions into a final value per key. Once you have defined a UDAF class, you can use it in your agg call alongside any other aggregations you want.

Here is a quick example I knocked out. Note that it hard-codes the array length and should probably be made more robust, but it should get you most of the way there.

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction


class ArrayCombine extends UserDefinedAggregateFunction {
  // The input this aggregation will receive (each row)
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("value", ArrayType(IntegerType)) :: Nil)

  // Your intermediate state as you are updating with data from each row
  override def bufferSchema: StructType =
    StructType(StructField("value", ArrayType(IntegerType)) :: Nil)

  // This is the output type of your aggregation function.
  override def dataType: DataType = ArrayType(IntegerType)

  override def deterministic: Boolean = true

  // This is the initial value for your buffer schema.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Array.fill(100)(0)  // start from an array of 100 zeros
  }

  // Given a new input row, update our state
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val sums = buffer.getSeq[Int](0)
    val newVals = input.getSeq[Int](0)

    buffer(0) = sums.zip(newVals).map { case (a, b) => a + b }
  }

  // After we have finished computing intermediate values for each partition, combine the partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val sums1 = buffer1.getSeq[Int](0)
    val sums2 = buffer2.getSeq[Int](0)

    buffer1(0) = sums1.zip(sums2).map { case (a, b) => a + b }
  }

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    buffer.getSeq[Int](0)
  }
}

Then call it like this:

val arrayUdaf = new ArrayCombine()
df.groupBy().agg(arrayUdaf($"value")).show
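And, as noted above, nothing stops you from mixing the UDAF with other aggregations in the same agg call. A hedged sketch, assuming a hypothetical grouping column named key in your real data (the toy df above has no such column):

import org.apache.spark.sql.functions.count

// Member-wise array sums plus a plain row count, computed in one pass per group.
df.groupBy($"key")
  .agg(arrayUdaf($"value") as "memberwiseSums", count($"value") as "rows")
  .show(false)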