实现List#flatMap

时间:2013-12-07 01:55:20

标签: scala functional-programming

是否有更好的功能方式来编写flatMap

def flatMap[A,B](list: List[A])(f: A => List[B]): List[B] =
    list.map(x => f(x)).flatten

从概念上讲,我理解flatMap flatten

3 个答案:

答案 0 :(得分:11)

另一种方法:

def flatMap[A, B](list: List[A])(f: A => List[B]): List[B] =
  list.foldLeft(List[B]())(_ ++ f(_))

我不知道“更好”。 (如果我们开始谈论有效的实施,那就是另一种蠕虫......)

答案 1 :(得分:3)

只是为了充实答案,您还可以使用模式匹配将其定义为递归函数:

def flatMap[A, B](list: List[A])(f: A => List[B]): List[B] = list match {
  case (x::xs) => f(x) ++ flatMap(xs)(f)
  case _ => Nil
}

或明确尾递归:

import scala.annotation.tailrec

def flatMapTailRec[A, B](list: List[A])(f: A => List[B]): List[B] = {
  @tailrec
  def _flatMap(result: List[B])(input: List[A])(f: A => List[B]): List[B] = input match {
    case (x::xs) => _flatMap(f(x) ++ result)(xs)(f)
    case _ => result
  }
  _flatMap(List[B]())(list)(f)
}

我使用以下示例输入进行了一些快速,非严格的基准测试:

val input = (0 to 1000).map(_ => (0 to 1000).toList).toList

从最快到最慢:

  • flatMap(input)(x => x)
    • 0.02069937453
  • flatMapTailRec(input)(x => x)
    • 0.02335651054
  • input.flatMap(x => x)
    • 0.0297564358
  • flatMapFoldLeft(input)(x => x)
    • 12.940458234

我有点惊讶foldLeft比其他人慢得多。有兴趣看看如何在源中实际定义flatMap。我试着看自己,但目前还有很多东西要经过> _>。

编辑:正如Daniel Sobral在另一个答案的评论中指出的那样,这些实现仅限于List[A]。您可以编写一个更通用的实现,适用于任何可映射类型。代码很快就会变得复杂得多。

答案 2 :(得分:0)

您的解决方案已经非常实用,并强调flatMap名称的来源。请注意x => f(x)f,因此归结为:

list.map(f).flatten

由于串联引起的二次行为,使用foldLeft是一个糟糕的主意。例如。 (((a++b)++c)++d)++List()会在a上迭代4次,在b上迭代3次等。

很好foldRight