Recursively creating a Spark filter with multiple predicates in Scala

Date: 2018-06-15 13:08:06

Tags: scala apache-spark

I have a DataFrame with a bunch of values:

// assumes spark.implicits._ is in scope (as in spark-shell) for toDF
val df = List(
  (2017, 1, 1234),
  (2017, 2, 1234),
  (2017, 3, 1234),
  (2017, 4, 1234),
  (2018, 1, 12345),
  (2018, 2, 12346),
  (2018, 3, 12347),
  (2018, 4, 12348)
).toDF("year", "month", "employeeCount")
  

df: org.apache.spark.sql.DataFrame = [year: int, month: int, employeeCount: int]

I want to filter that DataFrame against a list of (year, month) pairs:

val filterValues = List((2018, 1), (2018, 2))

I can easily cheat and write code that achieves that:

df.filter(
  (col("year") === 2018 && col("month") === 1) || 
  (col("year") === 2018 && col("month") === 2)
).show

But of course that's unsatisfactory, because filterValues may change and I want the filter to be driven by whatever happens to be in that list.

Is it possible to build my filter_expression dynamically and then pass it to df.filter(filter_expression)? I can't figure out how.

3 answers:

Answer 0 (score: 3)

Based on your comment:

> Imagine somebody calling this from the command line, e.g. --filterColumns "year,month" --filterValues "2018|1,2018|2"

val filterValues = "2018|1,2018|2"
val filterColumns = "year,month"

You can get the list of columns:

val colnames = filterColumns.split(',')

convert the data to a local Dataset (adding a schema where needed):

val filter = spark.read.option("delimiter", "|")
  .csv(filterValues.split(',').toSeq.toDS)
  .toDF(colnames: _*)

and do a left semi join:

df.join(filter, colnames, "left_semi").show
// +----+-----+-------------+             
// |year|month|employeeCount|
// +----+-----+-------------+
// |2018|    1|        12345|
// |2018|    2|        12346|
// +----+-----+-------------+
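
If the filter columns need real types rather than the strings that csv infers, an explicit schema can be attached up front. A minimal sketch, assuming both columns are integers (filterSchema and typedFilter are illustrative names, not part of the original answer):

import org.apache.spark.sql.types._

// same csv-over-Dataset trick as above, but with a declared schema,
// so the filter columns come out as int rather than string
val filterSchema = StructType(Seq(
  StructField("year", IntegerType),
  StructField("month", IntegerType)))

val typedFilter = spark.read
  .option("delimiter", "|")
  .schema(filterSchema)
  .csv(filterValues.split(',').toSeq.toDS)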

An expression like this should work as well:

import org.apache.spark.sql.functions._

val pred = filterValues
  .split(",")
  .map(x => colnames.zip(x.split('|'))
                    .map { case (c, v) => col(c) === v }
                    .reduce(_ && _))
  .reduce(_ || _)

df.where(pred).show
// +----+-----+-------------+
// |year|month|employeeCount|
// +----+-----+-------------+
// |2018|    1|        12345|
// |2018|    2|        12346|
// +----+-----+-------------+

but it will require a bit more work if any type casting is needed.
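
A sketch of what that extra work might look like, casting each literal to the matching column's type (the cast-per-column approach is my assumption, not part of the original answer):

// look up each filter column's DataType in df's schema
// and cast the string literal to it before comparing
val typedPred = filterValues
  .split(",")
  .map(x => colnames.zip(x.split('|'))
    .map { case (c, v) => col(c) === lit(v).cast(df.schema(c).dataType) }
    .reduce(_ && _))
  .reduce(_ || _)

df.where(typedPred).show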

Answer 1 (score: 1)

You can always do this using a udf function:

val filterValues = List((2018, 1), (2018, 2))

import org.apache.spark.sql.functions._
def filterUdf = udf((year: Int, month: Int) => filterValues.exists(x => x._1 == year && x._2 == month))

df.filter(filterUdf(col("year"), col("month"))).show(false)
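
Since filterValues is an ordinary Scala list captured by the udf's closure, turning it into a Set makes the per-row lookup constant-time. A small variant sketch (filterSet and filterUdf2 are hypothetical names):

// Set.apply doubles as a membership test
val filterSet = filterValues.toSet
def filterUdf2 = udf((year: Int, month: Int) => filterSet((year, month)))

df.filter(filterUdf2(col("year"), col("month"))).show(false)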

Update:

You commented:

> What I meant is that the list of columns to filter on (and the corresponding list of respective values) would be provided from somewhere else at runtime.

For that you would provide the list of column names as well, so the solution would look like the following:

val filterValues = List((2018, 1), (2018, 2))
val filterColumns = List("year", "month")

import org.apache.spark.sql.functions._
def filterUdf = udf((unknown: Seq[Int]) =>
  filterValues.exists(x =>
    !x.productIterator.toList.zip(unknown).map(y => y._1 == y._2).contains(false)))

df.filter(filterUdf(array(filterColumns.map(col): _*))).show(false)
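
One caveat: array requires all of its member columns to share a type, so if the filter columns ever differ in type they must be cast first. A hedged sketch, assuming every filter column is integer-compatible:

// cast each column before packing it into the array
df.filter(filterUdf(array(filterColumns.map(c => col(c).cast("int")): _*))).show(false)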

Answer 2 (score: 0)

You can build your filter_expression like this:

val df = List(
  (2017, 1, 1234),
  (2017, 2, 1234),
  (2017, 3, 1234),
  (2017, 4, 1234),
  (2018, 1, 12345),
  (2018, 2, 12346),
  (2018, 3, 12347),
  (2018, 4, 12348)
).toDF("year", "month", "employeeCount")

val filterValues = List((2018, 1), (2018, 2))

val filter_expression = filterValues
  .map{case (y,m) => col("year") === y and col("month") === m}
  .reduce(_ || _)

df
  .filter(filter_expression)
  .show()

+----+-----+-------------+
|year|month|employeeCount|
+----+-----+-------------+
|2018|    1|        12345|
|2018|    2|        12346|
+----+-----+-------------+
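
The same map/reduce pattern generalizes to a runtime-provided column list, which is what the comments on the question asked about. A sketch under that assumption (filterColumns, filterRows, and dynamicFilter are illustrative names, not part of the original answer):

val filterColumns = List("year", "month")
val filterRows = List(List(2018, 1), List(2018, 2))

// one conjunction per row of values, OR-ed together
val dynamicFilter = filterRows
  .map(row => filterColumns.zip(row)
    .map { case (c, v) => col(c) === v }
    .reduce(_ && _))
  .reduce(_ || _)

df.filter(dynamicFilter).show()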