How to merge array columns in a Spark DataFrame

Asked: 2016-12-28 22:15:41

Tags: scala apache-spark spark-dataframe

Say I have the following DataFrame:

id | myStruct
__________________________________
1  | [val1, val2]
__________________________________
2  | [val3, val4]
__________________________________
1  | [val5, val6]

I want to group all the myStructs that share the same id into a single array column of myStructs, so the DataFrame above should become:

id | myStructs
__________________________________
1  | [[val1, val2], [val5, val6]]
__________________________________
2  | [[val3, val4]]

I know there is an array function, but it only turns each column into an array of size 1. How can I merge the resulting arrays?

I am using Spark 1.5.2 in the Scala shell.

Since I am on Spark 1.5.2, I cannot use collect_list or collect_set.
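(For context, the size-1 arrays mentioned above presumably come from something like the following sketch; array here is org.apache.spark.sql.functions.array:)

import org.apache.spark.sql.functions.array

// array() only wraps the existing column row by row, so each row still
// carries a single-element array instead of all structs for its id.
df.select(df("id"), array(df("myStruct")).alias("myStructs"))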

2 answers:

Answer 0 (score: 3):

If you are on Spark 1.5 and cannot upgrade, the simplest option is RDD.groupByKey:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// Extract (id, myStruct) pairs, group them by id, and rebuild Rows whose
// second field holds all structs collected for that id.
val rows = df.rdd
  .map { case Row(id, myStruct) => (id, myStruct) }
  .groupByKey
  .map { case (id, myStructs) => Row(id, myStructs.toSeq) }

// Reuse the original column types: keep "id" unchanged and wrap the
// "myStruct" type in an ArrayType for the collected column.
val schema = StructType(Seq(
  df.schema("id"),
  StructField("myStructs", ArrayType(df.schema("myStruct").dataType))
))

sqlContext.createDataFrame(rows, schema)
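For reference, here is a minimal sketch of how the input df assumed above could be built in a Spark 1.5 shell (the values are taken from the question; each tuple becomes the myStruct struct column):

import sqlContext.implicits._

// Sample input: one struct per row, with repeating ids.
val df = sc.parallelize(Seq(
  (1, ("val1", "val2")),
  (2, ("val3", "val4")),
  (1, ("val5", "val6"))
)).toDF("id", "myStruct")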

This can be generalized to composite keys and values by packing them into struct "pairs" like this:

import org.apache.spark.sql.functions.struct

df.select(
  struct($"key1", $"key2", ..., $"keyn").alias("id"),
  struct($"val1", $"val2", ..., $"valn").alias("myStruct")
)
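As a hedged illustration (the key and value column names below are placeholders, not columns from the original data), the result of that select can then be fed through the same groupByKey and createDataFrame steps as before, reusing the imports above:

// Hypothetical composite-key version: group by a (key1, key2) struct and
// collect the (val1, val2) structs into an array per composite key.
val paired = df.select(
  struct($"key1", $"key2").alias("id"),
  struct($"val1", $"val2").alias("myStruct")
)

val pairedRows = paired.rdd
  .map { case Row(id, myStruct) => (id, myStruct) }
  .groupByKey
  .map { case (id, myStructs) => Row(id, myStructs.toSeq) }

val pairedSchema = StructType(Seq(
  paired.schema("id"),
  StructField("myStructs", ArrayType(paired.schema("myStruct").dataType))
))

sqlContext.createDataFrame(pairedRows, pairedSchema)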

Answer 1 (score: 1):

In Spark 2.0+ you could use collect_list for this:

scala> val df = sc.parallelize(Seq((1, ("v1", "v2")), (2, ("v3", "v4")), (1, ("v5", "v6")))).toDF("id", "myStruct")
df: org.apache.spark.sql.DataFrame = [id: int, myStruct: struct<_1: string, _2: string>]

scala> df.show
+---+--------+
| id|myStruct|
+---+--------+
|  1| [v1,v2]|
|  2| [v3,v4]|
|  1| [v5,v6]|
+---+--------+

scala> df.groupBy("id").agg(collect_list($"myStruct")).show
+---+----------------------+
| id|collect_list(myStruct)|
+---+----------------------+
|  1|    [[v1,v2], [v5,v6]]|
|  2|             [[v3,v4]]|
+---+----------------------+

But in Spark 1.5.2 you need something like this:

scala> val df2 = df.select($"id", $"myStruct._1".as("p1"), $"myStruct._2".as("p2"))
df2: org.apache.spark.sql.DataFrame = [id: int, p1: string, p2: string]

scala> df2.show
+---+---+---+
| id| p1| p2|
+---+---+---+
|  1| v1| v2|
|  2| v3| v4|
|  1| v5| v6|
+---+---+---+

scala> val rdd = df2.rdd.map{case Row(id: Int, p1: String, p2: String) => (id, (p1, p2))}
rdd: org.apache.spark.rdd.RDD[(Int, (String, String))] = MapPartitionsRDD[47] at map at <console>:32

scala> val finalDF = rdd.groupByKey.map(x => (x._1, x._2.toList)).toDF("id", "structs")
finalDF: org.apache.spark.sql.DataFrame = [id: int, structs: array<struct<_1:string,_2:string>>]

scala> finalDF.show
+---+------------------+
| id|           structs|
+---+------------------+
|  1|[[v1,v2], [v5,v6]]|
|  2|         [[v3,v4]]|
+---+------------------+