爆炸功能的反作用

时间:2019-09-29 09:44:02

标签: scala apache-spark apache-spark-sql

在使用spark-2.4的scala中,我想过滤列中数组中的值。

来自

+---+------------+
| id|      letter|
+---+------------+
|  1|[x, xxx, xx]|
|  2|[yy, y, yyy]|
+---+------------+

收件人

+---+-------+
| id| letter|
+---+-------+
|  1|[x, xx]|
|  2|[yy, y]|
+---+-------+

我考虑过使用explode + filter

val res = Seq(("1", Array("x", "xxx", "xx")), ("2", Array("yy", "y", "yyy"))).toDF("id", "letter")
res.withColumn("tmp", explode(col("letter"))).filter(length(col("tmp")) < 3).drop(col("letter")).show()

我得到了

+---+---+
| id|tmp|
+---+---+
|  1|  x|
|  1| xx|
|  2| yy|
|  2|  y|
+---+---+

如何按ID zip / groupBy回来?

或者也许有更好,更优化的解决方案?

2 个答案:

答案 0 :(得分:6)

您可以在Spark 2.4中过滤不带explode()的数组:

res.withColumn("letter", expr("filter(letter, x -> length(x) < 3)")).show()

输出:

+---+-------+
| id| letter|
+---+-------+
|  1|[x, xx]|
|  2|[yy, y]|
+---+-------+

答案 1 :(得分:2)

在Spark 2.4+中,高阶函数是行之有效的方法(filter),或者使用collect_list

res.withColumn("tmp",explode(col("letter")))
  .filter(length(col("tmp")) < 3)
  .drop(col("letter"))
  // aggregate back
  .groupBy($"id")
  .agg(collect_list($"tmp").as("letter"))
  .show()

给予:

+---+-------+
| id| letter|
+---+-------+
|  1|[x, xx]|
|  2|[yy, y]|
+---+-------+

由于这会引入随机播放,因此最好使用UDF:

def filter_arr(maxLength:Int)= udf((arr:Seq[String]) => arr.filter(str => str.size<=maxLength))

res
  .select($"id",filter_arr(maxLength = 2)($"letter").as("letter"))
  .show()

给予:

+---+-------+
| id| letter|
+---+-------+
|  1|[x, xx]|
|  2|[yy, y]|
+---+-------+