仅将每行的非空列收集到数组中

时间:2019-11-07 17:30:48

标签: apache-spark apache-spark-sql

困难在于,我试图尽可能避免使用UDF。

我有一个数据集“ wordsDS”,其中包含许多空值:

+------+------+------+------+
|word_0|word_1|word_2|word_3|
+------+------+------+------+
|     a|     b|  null|     d|
|  null|     f|     m|  null|
|  null|  null|     d|  null|
+--------------+------+-----|

我需要收集要排列的每一行的所有列。我事先不知道列数,所以我正在使用columns()方法。

groupedQueries = wordsDS.withColumn("collected",
      functions.array(Arrays.stream(wordsDS.columns())
               .map(functions::col).toArray(Column[]::new)));;

但是这种方法会产生空元素

+--------------------+
|           collected|
+--------------------+
|           [a, b,,d]|
|          [, f, m,,]|
|            [,, d,,]|
+--------------------+

相反,我需要以下结果:

+--------------------+
|           collected|
+--------------------+
|           [a, b, d]|
|              [f, m]|
|                 [d]|
+--------------------+

因此,基本上,我需要收集每一行的所有列以符合以下要求:

  1. 结果数组不包含空元素。
  2. 不知道预先的列数。

我也采用了对数据集的“收集”列进行空值过滤的方法,但是除了UDF之外,别无其他选择。我正在尝试避免UDF,以免降低性能,如果有人可以建议一种方法以尽可能少的开销为空值过滤数据集的“收集”列,那将非常有帮助

2 个答案:

答案 0 :(得分:0)

您可以使用array("*")将所有元素放入1个数组中,然后使用array_except(需要Spark 2.4+)来过滤出空值:

df
  .select(array_except(array("*"),array(lit(null))).as("collected"))
  .show()

给予

+---------+
|collected|
+---------+
|[a, b, d]|
|   [f, m]|
|      [d]|
+---------+

答案 1 :(得分:0)

火花<2.0,您可以使用def删除null

scala> var df = Seq(("a",  "b",  "null",  "d"),("null",  "f",  "m",  "null"),("null",  "null",  "d",  "null")).toDF("word_0","word_1","word_2","word_3")


scala> def arrayNullFilter = udf((arr: Seq[String]) => arr.filter(x=>x != "null"))

scala> df.select(array('*).as('all)).withColumn("test",arrayNullFilter(col("all"))).show
+--------------------+---------+
|                 all|     test|
+--------------------+---------+
|     [a, b, null, d]|[a, b, d]|
|  [null, f, m, null]|   [f, m]|
|[null, null, d, n...|      [d]|
+--------------------+---------+

希望这对您有所帮助。