根据提供的列表过滤数组列

时间:2018-01-14 11:51:01

标签: scala apache-spark apache-spark-sql

我在数据框中有以下类型:

 root
 |-- id: string (nullable = true)
 |-- items: array (nullable = true)
 |    |-- element: string (containsNull = true)

输入:

val rawData = Seq(("id1",Array("item1","item2","item3","item4")),("id2",Array("item1","item2","item3")))
val data = spark.createDataFrame(rawData)

和项目清单:

 val filter_list = List("item1", "item2")

我想过滤掉filter_list中不存在的项目,类似于array_contains的工作方式,但它不在提供的字符串列表上工作,只有一个值。

所以输出看起来像这样:

val rawData = Seq(("id1",Array("item1","item2")),("id2",Array("item1","item2")))
val data = spark.createDataFrame(rawData)

我尝试用以下UDF解决这个问题,但我可能在Scala和Spark之间混合使用类型:

def filterItems(flist: List[String]) = udf {
  (recs: List[String]) => recs.filter(item => flist.contains(item))
}

我正在使用Spark 2.2

谢谢!

1 个答案:

答案 0 :(得分:2)

你的代码几乎是正确的。您所要做的就是将List替换为Seq

def filterItems(flist: List[String]) = udf {
  (recs: Seq[String]) => recs.filter(item => flist.contains(item))
}

将签名从List[String] => UserDefinedFunction更改为SeqString] => UserDefinedFunction也是有意义的,但这不是必需的。

参考SQL Programming Guide - Data Types