我已经实现了以下mergesort代码。但是当整数的数量高达100000时,我在算法的合并过程中得到stackOverFlowError。我正在使用模式匹配和递归进行合并过程。我理解使用递归进行合并过程并不是最佳的,因为这个输入的深度将高达50000.但是因为我使用scala,我期待一些编译器优化使递归调用迭代,因为这些是尾递归调用。你能帮我理解为什么我仍然在下面的代码中得到StackOverFlowerror吗?请提供有关如何在scala中更有效地编写此内容的输入信息? 以下是代码:
package common
object Merge {
def main(args: Array[String]) = {
val source = scala.io.Source.fromFile("IntegerArray.txt")
val data = source.getLines.map {line => line.toInt}.toList
println(data.length)
val res = mergeSort(data)
println(res)
}
def mergeSort(data: List[Int]): List[Int] = {
if(data.length <= 1) {data }
else {
val mid = (data.length)/2
val (l, r) = data.splitAt(mid)
val l1 = mergeSort(l)
val l2 = mergeSort(r)
merge(l1, l2)
}
}
def merge(l: List[Int], r: List[Int]): List[Int] = {
l match {
case List() => r
case x::xs => {
r match {
case List() => l
case y::ys => {
if(x<y) {
x :: merge(xs, r)
} else {
y :: merge(l, ys)
}
}
}
}
}
}
}
以下是我得到的例外情况:
Exception in thread "main" java.lang.StackOverflowError
at common.Merge$.merge(Merge.scala:30)
at common.Merge$.merge(Merge.scala:30)
at common.Merge$.merge(Merge.scala:30)
at common.Merge$.merge(Merge.scala:32)
at common.Merge$.merge(Merge.scala:32)
at common.Merge$.merge(Merge.scala:30)
at common.Merge$.merge(Merge.scala:32)
at common.Merge$.merge(Merge.scala:30)
at common.Merge$.merge(Merge.scala:32)
at common.Merge$.merge(Merge.scala:32)
at common.Merge$.merge(Merge.scala:30)
at common.Merge$.merge(Merge.scala:30)
at common.Merge$.merge(Merge.scala:32)
at common.Merge$.merge(Merge.scala:32)
at common.Merge$.merge(Merge.scala:32)
at common.Merge$.merge(Merge.scala:30)
at common.Merge$.merge(Merge.scala:32)
at common.Merge$.merge(Merge.scala:32)
at common.Merge$.merge(Merge.scala:32)
at common.Merge$.merge(Merge.scala:30)
答案 0 :(得分:3)
合并排序需要递归,但这不是问题,因为它是O(log n)。 merge
方法应优化为循环,因为它是O(n)。
TailRec优化仅在递归调用是最后一个命令时才有效,在您的情况下,最后一个命令是列表连接(或前置)。
您可以添加@tailrec
注释。编译器将始终尝试优化,但这样它会让你知道它是否能够做到。
merge(l1, l2, Nil)
...
@tailrec
def merge(l: List[Int], r: List[Int], acc: List[Int]): List[Int] = {
l match {
case List() => acc ::: r
case x::xs => {
r match {
case List() => acc ::: l
case y::ys => {
val (item, lTail, rTail) =
if(x<y) (x, xs, r)
else (y, l, ys)
merge(lTail, rTail, acc:::List(item))
}
}
}
}
}
策略是对结果使用累加器,在基本情况下返回累加器而不是Nil列表。这样编译器就可以进行TailRec优化。
还要考虑编写如下代码:
@tailrec
def merge(l: List[Int], r: List[Int], acc: List[Int]): List[Int] = {
if (l.isEmpty) acc ::: r
else if (r.isEmpty) acc ::: l
else {
val (item, lTail, rTail) =
if (l.head<r.head) (l.head, l.tail, r)
else (r.head, l, r.tail)
merge(lTail, rTail, acc:::List(item))
}
}
我发现这种方式更简单易懂。
另请注意,尾递归调用并不需要只有一个,如前面的示例所示,因此只要递归调用是最后一个,您就可以返回到之前的if-else调用每个分支:
@tailrec
def merge(l: List[Int], r: List[Int], acc: List[Int]): List[Int] = {
if (l.isEmpty) acc ::: r
else if (r.isEmpty) acc ::: l
else {
if (l.head < r.head)
merge(l.tail, r, acc ::: List(l.head))
else
merge(l, r.tail, acc ::: List(r.head))
}
}
答案 1 :(得分:1)
你也可以匹配元组:
def merge(l: List[Int], r: List[Int], acc: List[Int]): List[Int] = (l,r) match {
case (lh :: lt, rh :: rt) =>
if (lh < rh)
merge(lt, r, lh :: acc)
else
merge(l, rt, rh :: acc)
case _ => acc.reverse ::: l ::: r
}
如果您反过来累积,您的运行时间将不依赖于:::
的实施效率,您将获得O(n)