scala中mergesort的StackOverflowError

时间:2014-10-24 17:09:47

标签: scala recursion scala-2.10

我已经实现了以下mergesort代码。但是当整数的数量高达100000时,我在算法的合并过程中得到stackOverFlowError。我正在使用模式匹配和递归进行合并过程。我理解使用递归进行合并过程并不是最佳的,因为这个输入的深度将高达50000.但是因为我使用scala,我期待一些编译器优化使递归调用迭代,因为这些是尾递归调用。你能帮我理解为什么我仍然在下面的代码中得到StackOverFlowerror吗?请提供有关如何在scala中更有效地编写此内容的输入信息? 以下是代码:

package common

object Merge {
  def main(args: Array[String]) = {
    val source = scala.io.Source.fromFile("IntegerArray.txt")
    val data = source.getLines.map {line => line.toInt}.toList
    println(data.length)
    val res = mergeSort(data)
    println(res)
  }
  def mergeSort(data: List[Int]): List[Int] = {
    if(data.length <= 1) {data }
    else {
      val mid = (data.length)/2
      val (l, r) = data.splitAt(mid)
      val l1 = mergeSort(l)
      val l2 = mergeSort(r)
      merge(l1, l2)
    }
  }

  def merge(l: List[Int], r: List[Int]): List[Int] = {
    l match {
      case List() => r
      case x::xs => {
        r match {
          case List() => l
          case y::ys => {
            if(x<y) {
              x :: merge(xs, r)
            } else {
              y :: merge(l, ys)
            }
          }
        }
      }
    }
  }
}

以下是我得到的例外情况:

Exception in thread "main" java.lang.StackOverflowError
    at common.Merge$.merge(Merge.scala:30)
    at common.Merge$.merge(Merge.scala:30)
    at common.Merge$.merge(Merge.scala:30)
    at common.Merge$.merge(Merge.scala:32)
    at common.Merge$.merge(Merge.scala:32)
    at common.Merge$.merge(Merge.scala:30)
    at common.Merge$.merge(Merge.scala:32)
    at common.Merge$.merge(Merge.scala:30)
    at common.Merge$.merge(Merge.scala:32)
    at common.Merge$.merge(Merge.scala:32)
    at common.Merge$.merge(Merge.scala:30)
    at common.Merge$.merge(Merge.scala:30)
    at common.Merge$.merge(Merge.scala:32)
    at common.Merge$.merge(Merge.scala:32)
    at common.Merge$.merge(Merge.scala:32)
    at common.Merge$.merge(Merge.scala:30)
    at common.Merge$.merge(Merge.scala:32)
    at common.Merge$.merge(Merge.scala:32)
    at common.Merge$.merge(Merge.scala:32)
    at common.Merge$.merge(Merge.scala:30)

2 个答案:

答案 0 :(得分:3)

合并排序需要递归,但这不是问题,因为它是O(log n)。 merge方法应优化为循环,因为它是O(n)。

TailRec优化仅在递归调用是最后一个命令时才有效,在您的情况下,最后一个命令是列表连接(或前置)。

您可以添加@tailrec注释。编译器将始终尝试优化,但这样它会让你知道它是否能够做到。

  merge(l1, l2, Nil)
  ...

  @tailrec
  def merge(l: List[Int], r: List[Int], acc: List[Int]): List[Int] = {
    l match {
      case List() => acc ::: r
      case x::xs => {
        r match {
          case List() => acc ::: l
          case y::ys => {
            val (item, lTail, rTail) =
               if(x<y) (x, xs, r)
               else (y, l, ys)
            merge(lTail, rTail, acc:::List(item))
          }
        }
      }
    }
  }

策略是对结果使用累加器,在基本情况下返回累加器而不是Nil列表。这样编译器就可以进行TailRec优化。

还要考虑编写如下代码:

  @tailrec
  def merge(l: List[Int], r: List[Int], acc: List[Int]): List[Int] = {
    if (l.isEmpty) acc ::: r
    else if (r.isEmpty) acc ::: l
    else {
      val (item, lTail, rTail) =
        if (l.head<r.head) (l.head, l.tail, r)
        else (r.head, l, r.tail)
      merge(lTail, rTail, acc:::List(item))
    }
  }

我发现这种方式更简单易懂。

另请注意,尾递归调用并不需要只有一个,如前面的示例所示,因此只要递归调用是最后一个,您就可以返回到之前的if-else调用每个分支:

  @tailrec
  def merge(l: List[Int], r: List[Int], acc: List[Int]): List[Int] = {
    if (l.isEmpty) acc ::: r
    else if (r.isEmpty) acc ::: l
    else {
      if (l.head < r.head)
        merge(l.tail, r, acc ::: List(l.head))
      else
        merge(l, r.tail, acc ::: List(r.head))
    }
  }

答案 1 :(得分:1)

你也可以匹配元组:

def merge(l: List[Int], r: List[Int], acc: List[Int]): List[Int] = (l,r) match {
    case (lh :: lt, rh :: rt) => 
        if (lh < rh) 
             merge(lt, r, lh :: acc) 
        else 
             merge(l, rt, rh :: acc)
    case _ => acc.reverse ::: l ::: r
}

如果您反过来累积,您的运行时间将不依赖于:::的实施效率,您将获得O(n)