Explode multiple nested columns, perform agg, and join all the tables

Time: 2017-06-07 14:31:53

Tags: scala apache-spark dataframe apache-spark-sql

I would like to know if there is a more efficient alternative to this:

val df0 = df.select($"id", explode($"event.x0") as "n_0").groupBy("id").agg(sum("n_0") as "0")
val df1 = df.select($"id", explode($"event.x1") as "n_1").groupBy("id").agg(sum("n_1") as "1")
val df2 = df.select($"id", explode($"event.x2") as "n_2").groupBy("id").agg(sum("n_2") as "2")
val df3 = df.select($"id", explode($"event.x3") as "n_3").groupBy("id").agg(sum("n_3") as "3")


val final_df = df.join(df0, "id").join(df1, "id").join(df2, "id").join(df3, "id")
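
For reference, a minimal sketch of the assumed input shape (the schema isn't shown in the question, so the field names and sample values here are assumptions):

import org.apache.spark.sql.functions._
import spark.implicits._  // assuming a SparkSession named spark

// hypothetical input: each event holds several integer arrays x0..x3,
// and an id can have several matching rows
case class Event(x0: Seq[Int], x1: Seq[Int], x2: Seq[Int], x3: Seq[Int])

val df = Seq(
  (1, Event(Seq(1, 2), Seq(3), Seq(4, 5), Seq(6))),
  (1, Event(Seq(7), Seq(8), Seq(9), Seq(10))),
  (2, Event(Seq(1), Seq(1), Seq(1), Seq(1)))
).toDF("id", "event")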

I am trying something like this:

val df_x = df.select($"id", $"event", explode($"event.x0") as "0" )
            .select($"id", $"event", $"0", explode($"event.x1") as "1")
            .select($"id", $"event", $"0", $"1", explode($"event.x2") as "2")
            .groupBy("id")
            .agg(sum("0") as "0", sum("1") as "1", sum("2") as "2")

val final_df = df.join(df_x, "id")

Even though it runs faster(!), the aggregated values are wrong, so it doesn't actually work :(
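
To see why, a small illustration with hypothetical values: chained explodes produce the cross product of the arrays, so every element gets counted once per element of the other arrays.

val demo = Seq((1, Seq(1, 2), Seq(10))).toDF("id", "x0", "x1")

demo.select($"id", explode($"x0") as "n0", $"x1")
    .select($"id", $"n0", explode($"x1") as "n1")
    .groupBy("id").agg(sum("n1") as "sum_x1")
    .show()
// sum_x1 comes out as 20 rather than 10: the single x1 element was
// duplicated once for each of the two x0 elements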

Any ideas on how to reduce the number of joins?

1 Answer:

Answer 0 (score: 0)

Assuming there aren't too many matching records per id, you could use the collect_list aggregation function to collect all the matching arrays into an array of arrays, and then a user-defined function that sums these nested arrays:

import scala.collection.mutable
import org.apache.spark.sql.functions._

// flattens a collected array-of-arrays and sums all of its elements
val flattenAndSum = udf[Int, mutable.Seq[mutable.Seq[Int]]] { seqOfArrays => seqOfArrays.flatten.sum }

val sums = df.groupBy($"id").agg(
  collect_list($"event.x0") as "arr0",
  collect_list($"event.x1") as "arr1",
  collect_list($"event.x2") as "arr2",
  collect_list($"event.x3") as "arr3"
).select($"id",
  flattenAndSum($"arr0") as "0",
  flattenAndSum($"arr1") as "1",
  flattenAndSum($"arr2") as "2",
  flattenAndSum($"arr3") as "3"
)

df.join(sums, "id")
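
With the sample data sketched above, sums would come out as follows (values shown for illustration only):

sums.show()
// +---+---+---+---+---+
// | id|  0|  1|  2|  3|
// +---+---+---+---+---+
// |  1| 10| 11| 18| 16|
// |  2|  1|  1|  1|  1|
// +---+---+---+---+---+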

Alternatively, if that assumption cannot be made, you could create a user-defined aggregate function that performs the flattening and summing on the fly. This is safer and potentially faster, but requires more work:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// implement a UDAF that sums array elements as it aggregates,
// without first collecting all arrays per group:
class FlattenAndSum extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = new StructType().add("arr", ArrayType(IntegerType))
  override def bufferSchema: StructType = new StructType().add("sum", IntegerType)
  override def dataType: DataType = IntegerType
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, 0)

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val current = buffer.getAs[Int](0)
    // guard against null arrays (assumption: input arrays may be null)
    val arr = input.getAs[Seq[Int]](0)
    val toAdd = if (arr == null) 0 else arr.sum
    buffer.update(0, current + toAdd)
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getAs[Int](0) + buffer2.getAs[Int](0))
  }

  override def evaluate(buffer: Row): Any = buffer.getAs[Int](0)
}

// use it in aggregation:
val flattenAndSum = new FlattenAndSum()

val sums = df.groupBy($"id").agg(
  flattenAndSum($"event.x0") as "0",
  flattenAndSum($"event.x1") as "1",
  flattenAndSum($"event.x2") as "2",
  flattenAndSum($"event.x3") as "3"
)

df.join(sums, "id")
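
Note: on Spark 3.0 and later, UserDefinedAggregateFunction is deprecated in favor of an Aggregator registered via functions.udaf. A sketch of the same aggregation in that style (assuming Spark 3.x):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// sums the elements of each input array into a running Int total
object FlattenAndSumAgg extends Aggregator[Seq[Int], Int, Int] {
  def zero: Int = 0
  def reduce(acc: Int, arr: Seq[Int]): Int = acc + (if (arr == null) 0 else arr.sum)
  def merge(a: Int, b: Int): Int = a + b
  def finish(acc: Int): Int = acc
  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

val flattenAndSum = udaf(FlattenAndSumAgg)  // usable in agg(...) exactly as above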