定义一个UDF以根据外部列表过滤WrappedArray元素

时间:2019-05-14 11:26:20

标签: scala apache-spark

我正在尝试定义一种从DF中的WrappedArrays筛选元素的方法。该过滤器基于外部元素列表。

寻找解决方案,我发现了这个question。这非常相似,但似乎对我不起作用。我正在使用Spark 2.4.0。这是我的代码:

val df = sc.parallelize(Array((1, Seq("s", "v", "r")),(2, Seq("r", "a", "v")),(3, Seq("s", "r", "t")))).toDF("foo","bar")


def filterItems(flist: Seq[String]) = udf {
  (recs: Seq[String]) => recs match {
    case null => Seq.empty[String]
    case recs => recs.intersect(flist)
  }}

df.withColumn("filtercol", filterItems(Seq("s", "v"))(col("bar"))).show(5)

我的预期结果是:

+---+---------+---------+ 
|foo| bar|filtercol| 
+---+---------+---------+ 
| 1 |[s, v, r]|   [s, v]| 
| 2 |[r, a, v]|      [v]| 
| 3| [s, r, t]|      [s]| 
+---+---------+---------+

但是我遇到了这个错误:

java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$dependencies_ of type scala.collection.Seq in instance of org.apache.spark.rdd.MapPartitionsRDD

2 个答案:

答案 0 :(得分:3)

实际上,您无需花费太多精力就可以使用Spark 2.4中的内置功能:

import org.apache.spark.sql.functions.{array_intersect, array, lit}

val df = sc.parallelize(Array((1, Seq("s", "v", "r")),(2, Seq("r", "a", "v")),(3, Seq("s", "r", "t")))).toDF("foo","bar")

val ar = Seq("s", "v").map(lit(_))
df.withColumn("filtercol", array_intersect($"bar", array(ar:_*))).show

输出:

+---+---------+---------+
|foo|      bar|filtercol|
+---+---------+---------+
|  1|[s, v, r]|   [s, v]|
|  2|[r, a, v]|      [v]|
|  3|[s, r, t]|      [s]|
+---+---------+---------+

唯一棘手的部分是Seq("s", "v").map(lit(_)),它将把每个字符串映射到lit(i)中。 intersection函数接受两个数组。第一个是bar列的值。第二个是使用array(ar:_*)动态创建的,其中将包含lit(i)的值。

答案 1 :(得分:0)

如果将ArrayType的属性传递到UDF中,它将作为WrappedArray的实例到达,该实例不是一个List。因此,您应该将recs的类型更改为SeqIndexedSeqWrappedArray,通常我只使用普通的Seq

def filterItems(flist: List[String]) = udf {
  (recs: Seq[String]) => recs match {
    case null => Seq.empty[String]
    case recs => recs.intersect(flist)
  }}