我正在尝试根据“ id”列对以下数据集进行分组,并对元素“ values”列中的数组求和。如何使用Scala在Spark中进行操作?
输入:(2列的数据集,String类型的column1和Array [Int]类型的column2)
| id | values |
---------------
| A | [12,61,23,43]
| A | [43,11,24,45]
| B | [32,12,53,21]
| C | [11,12,13,14]
| C | [43,52,12,52]
| B | [33,21,15,24]
预期输出:(数据集或数据框)
| id | values |
---------------
| A | [55,72,47,88]
| B | [65,33,68,45]
| C | [54,64,25,66]
注意: 结果必须是灵活和动态的。也就是说,即使有1000列,或者即使文件是几个TB或PB,该解决方案也应保持良好状态。
答案 0 :(得分:0)
当您说它必须灵活时,我不确定您的意思,但是在我的头上,我可以想到几种方法。第一个(我认为是最漂亮的)使用udf
:
// Creating a small test example
val testDF = spark.sparkContext.parallelize(Seq(("a", Seq(1,2,3)), ("a", Seq(4,5,6)), ("b", Seq(1,3,4)))).toDF("id", "arr")
val sum_arr = udf((list: Seq[Seq[Int]]) => list.transpose.map(arr => arr.sum))
testDF
.groupBy('id)
.agg(sum_arr(collect_list('arr)) as "summed_values")
但是,如果您有数十亿个相同的ID,那么collect_list
当然是一个问题。在这种情况下,您可以执行以下操作:
testDF
.flatMap{case Row(id: String, list: Seq[Int]) => list.indices.map(index => (id, index, list(index)))}
.toDF("id", "arr_index", "arr_element")
.groupBy('id, 'arr_index)
.agg(sum("arr_element") as "sum")
.groupBy('id)
.agg(collect_list('sum) as "summed_values")
答案 1 :(得分:0)
下面的单行解决方案对我有用。
ds.groupBy("Country").agg(array((0 until n).map(i => sum(col("Values").getItem(i))) :_* ) as "Values")