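Spark DataFrame UDF to filter elements of an array column

Time: 2018-07-15 17:02:20

Tags: apache-spark dataframe

I want to define a UDF to filter a DataFrame in Spark. Specifically, I want to filter the elements of the array stored in a column.

Example: keep only the elements that start with "Z", i.e. remove every array element that does not start with Z.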

Original Data
+---+--------------+
| _1|            _2|
+---+--------------+
|id1|[AA,BB,CC,Z12]|
|id2| [AA,ZA,CC,Z3]|
|id2| [Z2,XX,CC,A2]|
+---+--------------+
Expected result
+---+-----------+
| _1| _2        |
+---+-----------+
|id1| [Z12]     |
|id2| [ZA,Z3]   |
|id2| [Z2]      |
+---+-----------+
Current result
+---+--------------+
| _1| _2           |
+---+--------------+
|id1| []           |
|id2| []           |
|id2| [Z2,XX,CC,A2]|
+---+--------------+

Current code

import org.apache.spark.sql.functions.udf

// Keep only the array elements that start with "Z"
val filterArray = udf((recs: Seq[String]) => {
    recs.filter(_.startsWith("Z"))
})

val rawData = Seq(("id1",Array("AA,BB,CC,Z12")),("id2",Array("AA,ZA,CC,Z3")),("id2",Array("AA,XX,CC,A2")))
var test = spark.createDataFrame(rawData)
test.show(4)
test = test.withColumn("_2", filterArray(test("_2")))
test.show(4)

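1 answer:

Answer 0 (score: 0)

The problem is that each of your arrays contains only one element: a single comma-separated string such as "AA,BB,CC,Z12". You should first split that string and then filter: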

// Split each comma-separated string into its parts, then keep the parts starting with "Z"
val filterArray = udf((recs: Seq[String]) => {
  recs.flatMap(_.split(",")).filter(_.startsWith("Z"))
})

Then you get

+---+--------+
| _1|      _2|
+---+--------+
|id1|   [Z12]|
|id2|[ZA, Z3]|
|id2|      []|
+---+--------+
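
To make the cause easier to see, here is a minimal sketch in plain Scala (no Spark needed) that applies the same two predicates to ordinary collections. The names `ArrayFilterSketch`, `keepZ` and `splitThenKeepZ` are hypothetical and only serve this illustration.

```scala
object ArrayFilterSketch {
  // Same logic as the UDF in the question: test each array element as a whole
  def keepZ(recs: Seq[String]): Seq[String] =
    recs.filter(_.startsWith("Z"))

  // Same logic as the UDF in this answer: split on commas first, then test each part
  def splitThenKeepZ(recs: Seq[String]): Seq[String] =
    recs.flatMap(_.split(",")).filter(_.startsWith("Z"))

  def main(args: Array[String]): Unit = {
    // Array("AA,BB,CC,Z12") holds ONE string, so startsWith sees "AA,BB,CC,Z12"
    println(keepZ(Seq("AA,BB,CC,Z12")))          // List()
    // A single string that happens to start with "Z" is kept in full
    println(keepZ(Seq("Z2,XX,CC,A2")))           // List(Z2,XX,CC,A2)
    // Splitting first yields the individual items
    println(splitThenKeepZ(Seq("AA,ZA,CC,Z3")))  // List(ZA, Z3)
  }
}
```

This matches the "Current result" table in the question: a row is either emptied or kept unchanged, depending only on whether its single string starts with "Z".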

If the data were defined as follows, with one array element per item, you could keep your current UDF:

val rawData = Seq(
  ("id1", Array("AA", "BB", "CC", "Z12")),
  ("id2", Array("AA", "ZA", "CC", "Z3")),
  ("id2", Array("AA", "XX", "CC", "A2"))
)
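
As a possible alternative to a UDF, and assuming Spark 3.0 or later (newer than the version this question was likely written against), the built-in `filter` higher-order function in `org.apache.spark.sql.functions` can do the same filtering on a properly defined array column. This is a different technique from the one in the answer above; the object name, app name and local master setting below are assumptions made for the sketch.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, filter}

object BuiltinArrayFilterSketch {
  def main(args: Array[String]): Unit = {
    // Assumption: a local session is enough for this illustration
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("array-filter-sketch")
      .getOrCreate()
    import spark.implicits._

    // One array element per item, as in the corrected data definition
    val df = Seq(
      ("id1", Seq("AA", "BB", "CC", "Z12")),
      ("id2", Seq("AA", "ZA", "CC", "Z3")),
      ("id2", Seq("AA", "XX", "CC", "A2"))
    ).toDF("_1", "_2")

    // filter(column, predicate) keeps the elements for which the predicate is true;
    // Column.startsWith builds the "starts with Z" test for each element
    val result = df.withColumn("_2", filter(col("_2"), x => x.startsWith("Z")))
    result.show(false)

    spark.stop()
  }
}
```

Because the predicate is expressed with Column operations rather than an opaque Scala function, Spark can optimize it along with the rest of the query, which a UDF does not allow.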