如何在Spark中使用scala语言进行聚合而不会爆炸

时间:2018-04-02 07:18:43

标签: scala apache-spark apache-spark-sql spark-dataframe

我使用Spark 2.2版本和Scala作为编程语言。

输入数据:

{"amount":"2.00","cal_group":[{}],"set_id":7057} {"amount":"1.00","cal_group":[{}],"set_id":7057} {"amount":"7.00","cal_group": [{"abc_cd":"abc00160","abc_cnt":6.0,"cde_cnt":7.0},{"abc_cd":"abc00160","abc_cnt":5.0,"cde_cnt":2.0},{"abc_cd":"abc00249","abc_cnt":0.0,"cde_cnt":1.0}],"set_id":7057}

输入数据框:

[2.00,WrappedArray([null,null,null]),7057]
[1.00,WrappedArray([null,null,null]),7057]
[7.00,WrappedArray([abc00160,6.0,7.0],[abc00160,5.0,2.0,],[abc00249,0.0,1.0]),7057]

输入数据架构:

|-- amount: string (nullable = true)
|-- cal_group: array (nullable = true)
|    |-- element: struct (containsNull = true)
|    |    |-- abc_cd: string (nullable = true)
|    |    |-- abc_cnt: double (nullable = true)
|    |    |-- cde_cnt: double (nullable = true)
|--set_id: double

注意 :每个包装数组都是一个包含 abc_cd 和其他2个度量列的结构。 < / p>

我想对输入数据进行两级聚合。它被称为步骤1和步骤2.

第1步:

我们需要获取每个 set_id 的金额总和,并在为 cal_group 执行collect_list时删除空值

我试过以下代码:

val res1=res.groupBy($"set_id").agg(sum($"amount").as('amount_total),collect_list(struct($"cal_group")).as('finalgroup))

按预期给予我金额总和。 但在这里我不知道如何跳过null WrappedArray cal_group 列。

输出:第1步

[7057,10.00,WrappedArray([WrappedArray([null,null,null])],[WrappedArray([null,null,null])],[WrappedArray([null,null,null])],[WrappedArray([abc00160,6.0,7.0],[abc00160,5.0,2.0],[abc00249,0.0,1.0])])

第2步:

然后我想要 abc_cd 代码级别的聚合2度量( abc_cnt,cde_cnt )。

这里的聚合可以通过cal_group列上的explode函数来完成。它将在行级转换cal_group记录,它将增加数据的行/数量。

所以,我尝试爆炸结构并在 abc_cd 上进行分组。

示例代码,如果使用explode函数进行求和:

   val res2 = res1.select($"set_id",$"amount_total",explode($"cal_group").as("cal_group"))
    val res1 = res2.select($"set_id",$"amount_total",$"cal_group")
                         .groupBy($"set_id",$"cal_group.abc_cd")
                         .agg(sum($"cal_group.abc_cnt").as('abc_cnt_sum),
                              sum($"cal_group.cde_cnt").as('cde_cnt_sum),
                              )

所以在这里,我不想爆炸col_group列。 因为它正在增加音量。

在第2步之后预期的输出:

[7057,10.00,WrappedArray(**[WrappedArray([null,null,null])],
                                       [WrappedArray([null,null,null])],
                                       [WrappedArray([null,null,null])],
                                       [WrappedArray([abc00160,11.0,9.0],
                                                     [abc00249,0.0,1.0])])

是否有任何可用选项,其中函数应在记录级别聚合并在收集之前删除null结构。

提前致谢。

2 个答案:

答案 0 :(得分:1)

您可以为第二部分聚合定义udf函数

import org.apache.spark.sql.functions._
def aggregateUdf = udf((nestedArray: Seq[Seq[Row]])=>
  nestedArray
    .flatMap(x => x
      .map(y => (y(0).asInstanceOf[String], (y(1).asInstanceOf[Double], y(2).asInstanceOf[Double]))))
      .filterNot(_._1 == null)
      .groupBy(_._1)
      .map(x => (x._1, x._2.map(_._2._1).sum, x._2.map(_._2._2).sum)).toArray
)

您可以在第一次聚合之后调用udf函数通过删除结构部分也需要修改

val finalRes=res
  .groupBy($"set_id")
  .agg(sum($"amount").as('amount_total),collect_list($"cal_group").as('finalgroup))
  .withColumn("finalgroup", aggregateUdf('finalgroup))

所以finalRes将是

+------+------------+-----------------------------------------+
|set_id|amount_total|finalgroup                               |
+------+------------+-----------------------------------------+
|7057  |10.0        |[[abc00249,0.0,1.0], [abc00160,11.0,9.0]]|
+------+------------+-----------------------------------------+

答案 1 :(得分:0)

我在json数据下面加载并加载以获得与您相同的模式:

{"amount":"2.00","cal_group":[{}],"set_id":7057.0}
{"amount":"1.00","cal_group":[{}],"set_id":7057}
{"amount":"7.00","cal_group": [{"abc_cd":"abc00160","abc_cnt":6.0,"cde_cnt":7.0},{"abc_cd":"abc00160","abc_cnt":5.0,"cde_cnt":2.0},{"abc_cd":"abc00249","abc_cnt":0.0,"cde_cnt":1.0}],"set_id":7057}
  

但在这里我不知道如何跳过null WrappedArray cal_group列

我认为collect_list会自动删除null,但在您的情况下它无法删除,因为您已使用struct进行不需要的聚合。因此,第1步的正确转换是:

val res1=res.groupBy($"set_id").agg(sum($"amount").as('amount_total),(collect_list($"cal_group")).as('finalgroup))

,它给出了以下输出(showprintSchema

+------+------------+--------------------------------------------------------------------------+
|set_id|amount_total|finalgroup                                                                |
+------+------------+--------------------------------------------------------------------------+
|7057.0|10.0        |[WrappedArray([abc00160,6.0,7.0], [abc00160,5.0,2.0], [abc00249,0.0,1.0])]|
+------+------------+--------------------------------------------------------------------------+
root
 |-- set_id: double (nullable = true)
 |-- amount_total: double (nullable = true)
 |-- finalgroup: array (nullable = true)
 |    |-- element: array (containsNull = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- abc_cd: string (nullable = true)
 |    |    |    |-- abc_cnt: double (nullable = true)
 |    |    |    |-- cde_cnt: double (nullable = true)

第2步

下面假设上面的代码是作为步骤1运行的。我只使用爆炸机制。

要处理您的数据结构,您必须进行两次爆炸,因为amount的{​​{1}}分组结构是双嵌套数组。下面是给出所需o / p的代码:

cal_group

带输出:

val res2 = res1.select($"set_id",$"amount_total",explode($"finalgroup").as("cal_group"))
val res3 = res2.select($"set_id",$"amount_total",explode($"cal_group").as("cal_group_exp"))
val res4 = res3.groupBy($"set_id",$"cal_group_exp.abc_cd")
                          .agg(sum($"cal_group_exp.abc_cnt").as('abc_cnt_sum),
                              sum($"cal_group_exp.cde_cnt").as('cde_cnt_sum))
res4.show(false)