Merge multiple ArrayType columns in Spark into a single ArrayType column

Time: 2018-08-30 06:20:13

Tags: scala apache-spark apache-spark-sql

I want to merge multiple ArrayType[StringType] columns in Spark into a single ArrayType[StringType] column. For merging two columns, I found a solution here:

Merge two spark sql columns of type Array[string] into a new Array[string] column

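But how do I merge them if I don't know the number of columns at compile time? I will only know the names of all the columns to merge at run time.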

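One option is to use the UDF defined in the Stack Overflow question above and call it repeatedly in a loop, adding two columns at a time. But this involves multiple passes over the whole DataFrame. Is there a way to do it in one go? Example of the expected output: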

+------+------+---------+
| col1 | col2 | combined|
+------+------+---------+
| [a,b]| [i,j]|[a,b,i,j]|
| [c,d]| [k,l]|[c,d,k,l]|
| [e,f]| [m,n]|[e,f,m,n]|
| [g,h]| [o,p]|[g,h,o,p]|
+------+------+---------+
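The `combined` column above is a plain, order-preserving concatenation that keeps any duplicate values. As an illustration only, here is that per-row rule modeled in plain Python (no Spark), with the list of column names supplied at run time; `combine_row` is a hypothetical helper, not a Spark API. In Spark 2.4 and later, the built-in `concat` function also accepts array columns, so the same result can be expressed in PySpark as `concat(*[col(c) for c in cols])`.

```python
from typing import Dict, List, Sequence


# Hypothetical helper: models one output row of the "combined" column.
# Arrays are concatenated in the order given by `cols`; duplicates are kept.
def combine_row(row: Dict[str, List[str]], cols: Sequence[str]) -> List[str]:
    combined: List[str] = []
    for name in cols:  # column names are only known at run time
        combined.extend(row[name])
    return combined


rows = [
    {"col1": ["a", "b"], "col2": ["i", "j"]},
    {"col1": ["c", "d"], "col2": ["k", "l"]},
]
cols = ["col1", "col2"]
for row in rows:
    print(combine_row(row, cols))
```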

2 answers:

Answer 0 (score: 0)

import org.apache.spark.SparkException
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.functions.{col, struct, udf}

// Names of the array columns to merge; only known at run time
val arrStr: Array[String] = Array("col1", "col2")

val arrCol: Array[Column] = arrStr.map(c => df(c))

// Concatenate the fields of a struct, each of which is an array of strings
def assemble(rowEntity: Any*): Seq[String] =
  rowEntity.flatMap {
    case v: Seq[_] =>
      v.map(_.asInstanceOf[String])
    case null =>
      throw new SparkException("Values to assemble cannot be null.")
    case o =>
      throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
  }

// Pack the selected columns into one struct so the UDF receives them as a single Row
val assembleFunc = udf { r: Row => assemble(r.toSeq: _*) }

val outputDf = df.select(col("*"), assembleFunc(struct(arrCol: _*)).as("combined"))

outputDf.show(false)
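To check the per-row logic of `assemble` without a Spark session, it can be ported to plain Python. This is a sketch that mirrors the Scala `assemble` function; `ValueError` and `TypeError` stand in for `SparkException`, and a tuple stands in for the struct `Row` that the UDF receives.

```python
from typing import Any, List


def assemble(*row_entity: Any) -> List[str]:
    """Concatenate every struct field (each a list of strings) in order."""
    output: List[str] = []
    for v in row_entity:
        if v is None:
            # The Scala version throws SparkException here
            raise ValueError("Values to assemble cannot be null.")
        if not isinstance(v, (list, tuple)):
            raise TypeError(f"{v!r} of type {type(v).__name__} is not supported.")
        output.extend(v)
    return output


# The struct of selected columns arrives in the UDF as one Row (here: a tuple)
row = (["e", "f"], ["m", "n"])
print(assemble(*row))  # ['e', 'f', 'm', 'n']
```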

Answer 1 (score: -1)

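  1. Walk through the DataFrame schema and collect all columns of type ArrayType[StringType].

  2. Create a new DataFrame by applying functions.array_union (available since Spark 2.4) to the first two columns. Note that array_union returns the union of the two arrays without duplicates, so repeated values are dropped.

  3. Iterate over the remaining columns and fold each one into the combined column.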
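Steps 2 and 3 amount to a left fold of a pairwise merge over the column list. Below is a plain-Python sketch of that fold (no Spark); `union_model` is a hypothetical stand-in that models `array_union`'s documented behavior (union without duplicates), assuming first-occurrence order is kept. In PySpark the same shape can be written as `reduce(lambda acc, c: array_union(acc, col(c)), cols[1:], col(cols[0]))`.

```python
from functools import reduce
from typing import List


def union_model(a: List[str], b: List[str]) -> List[str]:
    """Model of array_union: elements of a, then of b, without duplicates.
    Assumes the order of first occurrence is kept."""
    seen = set()
    result: List[str] = []
    for x in a + b:
        if x not in seen:
            seen.add(x)
            result.append(x)
    return result


row = {"col1": ["aa1", "bb1"], "col2": ["aa2", "bb2"],
       "col3": ["aa3", "bb3"], "col4": ["a", "ee"]}
cols = ["col1", "col2", "col3", "col4"]

# Step 2 seeds the fold with the first column; step 3 folds in the rest
combined = reduce(lambda acc, c: union_model(acc, row[c]), cols[1:], row[cols[0]])
print(combined)  # ['aa1', 'bb1', 'aa2', 'bb2', 'aa3', 'bb3', 'a', 'ee']
```

Unlike the concatenation in the question's example, a union collapses repeated values: `union_model(["a", "b"], ["b", "c"])` gives `["a", "b", "c"]`, not `["a", "b", "b", "c"]`.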

>>>from pyspark.sql import Row
>>>from pyspark.sql.functions import array_union, col
>>>from pyspark.sql.types import ArrayType, StringType
>>>df = spark.createDataFrame([Row(col1=['aa1', 'bb1'],
...                                 col2=['aa2', 'bb2'],
...                                 col3=['aa3', 'bb3'],
...                                 col4=['a', 'ee'], foo="bar"
...                                )])
>>>df.show()
+----------+----------+----------+-------+---+
|      col1|      col2|      col3|   col4|foo|
+----------+----------+----------+-------+---+
|[aa1, bb1]|[aa2, bb2]|[aa3, bb3]|[a, ee]|bar|
+----------+----------+----------+-------+---+
>>>cols = [col_.name for col_ in df.schema 
...       if col_.dataType == ArrayType(StringType()) 
...        or col_.dataType == ArrayType(StringType(), False)
...       ]
>>>print(cols)
['col1', 'col2', 'col3', 'col4']
>>>
>>>final_df = df.withColumn("combined", array_union(cols[0], cols[1]))
>>>
>>>for col_ in cols[2:]:
...    final_df = final_df.withColumn("combined", array_union(col('combined'), col(col_)))
>>>
>>>final_df.select("combined").show(truncate=False)
+-------------------------------------+
|combined                             |
+-------------------------------------+
|[aa1, bb1, aa2, bb2, aa3, bb3, a, ee]|
+-------------------------------------+