Spark数据帧 - 按数组中的位置对字符串进行分组

时间:2018-03-19 20:53:14

标签: scala apache-spark

在scala上使用Spark 1.6,如何按键按位置对代码列中的每个字符进行分组?第一个串起来,第二个字符在一起......等等

val someDF = Seq(
    (123, "0000"), 
    (123, "X000"),
    (123, "C111"),
    (124, "0000"),
    (124, "0000"),
    (124, "C200")).toDF("key", "code")

someDF.show()
+---+----+
|key|code|
+---+----+
|123|0000|
|123|X000|
|123|C111|
|124|0000|
|124|0000|
|124|C200|
+---+----+



val df = someDF.select($"key", split($"code","").as("code_split"))

df.show()
+---+--------------+
|key|    code_split|
+---+--------------+
|123|[0, 0, 0, 0, ]|
|123|[X, 0, 0, 0, ]|
|123|[C, 1, 1, 1, ]|
|124|[0, 0, 0, 0, ]|
|124|[0, 0, 0, 0, ]|
|124|[C, 2, 0, 0, ]|
+---+--------------+

使用collect_list,我可以一次为一列做到这一点。如何在没有循环的情况下为所有组合执行此操作?

df
.select($"id",
    $"code_split"(0).as("m1"),  
    $"code_split"(1).as("m2"),
    $"code_split"(2).as("m3"),
    $"code_split"(3).as("m4")
    )
.groupBy($"id").agg(
    collect_list($"m1"), 
    collect_list($"m2"),
    collect_list($"m3"),
    collect_list($"m4")
    )
.show()

+---+----------------+----------------+----------------+----------------+
| id|collect_list(m1)|collect_list(m2)|collect_list(m3)|collect_list(m4)|
+---+----------------+----------------+----------------+----------------+
|123|       [0, X, C]|       [0, 0, 1]|       [0, 0, 1]|       [0, 0, 1]|
|124|       [0, 0, C]|       [0, 0, 1]|       [0, 0, 0]|       [0, 0, 0]|
+---+----------------+----------------+----------------+----------------+

有没有办法在不重复agg的collect_list的情况下获得相同的结果?如果我有60个实例,我不想复制粘贴60次。

2 个答案:

答案 0 :(得分:1)

我认为必须分割code列才能实现结果,而是分成每个字符的列,而不是数组。这将有助于进一步分组角色。

可以通过以下方式完成此分割:

import org.apache.spark.sql.functions._

val originalDf: DataFrame = ...

// split function: returns a new dataframe with column "code{i}"
// containing the character at index "i" from "code" column 
private def splitCodeColumn(df: DataFrame, i: Int): DataFrame = {
  df.withColumn("code" + i, substring(originalDf("code"), i, 1))
}

// number of columns to split code in
val nbSplitColumns = "0000".length

val codeColumnSplitDf = 
  (1 to nbSplitColumns).foldLeft(originalDf){ case(df, i) => splitCodeColumn(df, i)}.drop("code")

// register it in order to use with Spark SQL
val splitTempViewName = "code_split"
codeColumnSplitDf.registerTempTable(splitTempViewName)

现在codeColumnSplitDf包含:

+---+-----+-----+-----+-----+
|key|code1|code2|code3|code4|
+---+-----+-----+-----+-----+
|123|    0|    0|    0|    0|
|123|    X|    0|    0|    0|
|123|    C|    1|    1|    1|
|124|    0|    0|    0|    0|
|124|    0|    0|    0|    0|
|124|    C|    2|    0|    0|
+---+-----+-----+-----+-----+

我们会使用collect_list函数来汇总按key分组的字符:

// collect_list calls to insert into SQL
val aggregateSelections = (1 to nbSplitColumns).map(i => s"collect_list(code$i) as code_$i").mkString(", ")

val sqlCtx: SQLContext = ...

// DataFrames with expected results
val resultDf = sqlCtx.sql(s"SELECT key, $aggregateSelections FROM $splitTempViewName GROUP BY key")

resultDf包含:

+---+---------+---------+---------+---------+
|key|   code_1|   code_2|   code_3|   code_4|
+---+---------+---------+---------+---------+
|123|[0, X, C]|[0, 0, 1]|[0, 0, 1]|[0, 0, 1]|
|124|[0, 0, C]|[0, 0, 2]|[0, 0, 0]|[0, 0, 0]|
+---+---------+---------+---------+---------+

更新

避免重复selectagg中的元素:

val codeSplitColumns =
  Seq(col("id")) ++ (0 until nbSplitColumns).map(i => col("code_split")(i).as("m" + i))

val aggregations =
  (0 until nbSplitColumns).map(i => collect_list(col("m" + i)))

df.select(codeSplitColumns:_*)
  .groupBy(col("id"))
  .agg(aggregations.head, aggregations.tail:_*)

答案 1 :(得分:0)

  
    

有没有办法在不重复agg的collect_list的情况下获得相同的结果?

  

是的,肯定有多种方法可以使用 collect_list一次。我将使用udf函数

向您展示

udf功能

def combineUdf = udf((strs: Seq[String])=> {
  val length = strs(0).length
  val groupLength = strs.length
  val result = for(i <- 0 until length; arr <- strs)yield arr(i).toString
  result.grouped(groupLength).map(_.toArray).toArray
})

combineUdf函数获取收​​集的字符串列表(因为您将只使用一个collect_list函数),然后解析数组Array格式Array[Array[String]]

为您提供所需的输出

只需将udfselect所有必要的列称为

即可
someDF.groupBy("key").agg(combineUdf(collect_list("code")).as("value"))
    .select(col("key") +: (0 to 3).map(x => col("value")(x).as("value_"+x)): _*)
    .show(false)

其中(0 to 3)可根据您的最终列必要性

而变化

它应该为您提供您想要的输出

+---+---------+---------+---------+---------+
|key|value_0  |value_1  |value_2  |value_3  |
+---+---------+---------+---------+---------+
|123|[0, X, C]|[0, 0, 1]|[0, 0, 1]|[0, 0, 1]|
|124|[0, 0, C]|[0, 0, 2]|[0, 0, 0]|[0, 0, 0]|
+---+---------+---------+---------+---------+

我希望答案很有帮助