spark scala - 按数组列

时间:2018-06-04 00:35:55

标签: arrays scala apache-spark mapreduce

我是一个非常新的火花scala。感谢您的帮助..    我有一个数据框

val df = Seq(
  ("a", "a1", Array("x1","x2")), 
  ("a", "b1", Array("x1")),
  ("a", "c1", Array("x2")),
  ("c", "c3", Array("x2")),
  ("a", "d1", Array("x3")),
  ("a", "e1", Array("x2","x1"))
).toDF("k1", "k2", "k3")

我正在寻找一种方法,用k1和k3对它进行分组,并在数组中收集k2。    但是,k3是一个数组,我需要应用包含(而不是精确    匹配)用于分组。换句话说,我正在寻找一个结果    像这样

k1   k3       k2                count
a   (x1,x2)   (a1,b1,c1,e1)     4
a    (x3)      (d1)             1
c    (x2)      (c3)             1

有人可以建议如何实现这个目标吗?

提前致谢!

1 个答案:

答案 0 :(得分:0)

我建议您分组按k1列收集k2和k3的结构列表将收集的列表传递给udf函数用于计数当k3中的数组包含在另一个k3数组中并添加k2元素时。

然后,您可以使用double-checkexplode表达式来获得所需的输出

以下是完整的工作解决方案

select

应该给你

val df = Seq(
  ("a", "a1", Array("x1","x2")),
  ("a", "b1", Array("x1")),
  ("a", "c1", Array("x2")),
  ("c", "c3", Array("x2")),
  ("a", "d1", Array("x3")),
  ("a", "e1", Array("x2","x1"))
  ).toDF("k1", "k2", "k3")

import org.apache.spark.sql.functions._
def containsGoupingUdf = udf((arr: Seq[Row]) => {
  val firstStruct =  arr.head
  val tailStructs =  arr.tail
  var result = Array((collection.mutable.Set(firstStruct.getAs[String]("k2")), firstStruct.getAs[scala.collection.mutable.WrappedArray[String]]("k3").toSet, 1))
  for(str <- tailStructs){
    var added = false
    for((res, index) <- result.zipWithIndex) {
      if (str.getAs[scala.collection.mutable.WrappedArray[String]]("k3").exists(res._2) || res._2.exists(x => str.getAs[scala.collection.mutable.WrappedArray[String]]("k3").contains(x))) {
        result(index) = (res._1 + str.getAs[String]("k2"), res._2 ++ str.getAs[scala.collection.mutable.WrappedArray[String]]("k3").toSet, res._3 + 1)
        added = true
      }
    }
    if(!added){
      result = result ++ Array((collection.mutable.Set(str.getAs[String]("k2")), str.getAs[scala.collection.mutable.WrappedArray[String]]("k3").toSet, 1))
    }
  }
  result.map(tuple => (tuple._1.toArray, tuple._2.toArray, tuple._3))
})

df.groupBy("k1").agg(containsGoupingUdf(collect_list(struct(col("k2"), col("k3")))).as("aggregated"))
    .select(col("k1"), explode(col("aggregated")).as("aggregated"))
    .select(col("k1"), col("aggregated._2").as("k3"), col("aggregated._1").as("k2"), col("aggregated._3").as("count"))
  .show(false)

我希望答案很有帮助,您可以根据自己的需要进行修改。