地图内的Scala Spark过滤器

时间:2019-04-24 21:52:57

标签: scala apache-spark

我想在映射时有效地过滤RDD。有可能吗?

这是我想做的伪代码:

for element in rdd:
    val opt = f(element)
    if (opt.nonEmpty) add_pair(opt.get, element)

这是在Scala Spark中实现伪代码的骇人方式:

rdd.map(element => (
    f(element).getOrElse(99),
    element
)).filter(tuple => tuple._1 != 99)

我无法找到干净的语法来做到这一点,因此我首先映射了所有元素,然后滤除了我不想要的元素。请注意,潜在的昂贵通话f(element)仅计算一次。如果我要在映射之前过滤元素(看起来会更干净),那么我最终会两次调用f,这样效率低下。

请不要将其标记为重复项。尽管存在类似的问题,但他们都没有实际回答这个问题。例如,this个潜在重复项将调用f两次,这样效率低下,因此无法回答此问题。

2 个答案:

答案 0 :(得分:4)

您可以只使用flatMap

//let's say your f returns Some(x*2) for even number and None for odd
def f(n: Int): Option[Int] = if (n % 2) Some(n*2) else None 

val rdd = sc.parallelize(List(1,2,3,4))
rdd.flatMap(f) // 4,8

// rdd.flatMap(f) or rdd.flatMap(f(_)) or rdd.flatMap(e => f(e))

如果您需要进一步传递元组并进行过滤,则只需使用嵌套的map

rdd.flatMap(e => f(e).map((_,e))) //(4,2),(8,4)

答案 1 :(得分:3)

您可以使用mapPartitions进行过滤并进行昂贵的计算。

rdd.mapPartitions( elements => 
  elements
      .map(element => (f(element),element))
      .filter(tuple => tuple._1.isDefined)
)

请注意,在这段代码中,filter是本机Scala收集方法,而不是Spark RDD过滤器。

或者,您也可以flatMap函数的结果

rdd.flatMap(element => f(element).map(result => (result,element)))