我试图在Scala中编写一个尾递归快速排序,通过构建一个延续来工作,而不使用蹦床。到目前为止,我有以下内容:
object QuickSort {
def sort[A: Ordering](toSort: Seq[A]): Seq[A] = {
val ordering = implicitly[Ordering[A]]
import ordering._
@scala.annotation.tailrec
def step(list: Seq[A], conts: List[Seq[A] => Seq[A]]): Seq[A] = list match {
case s if s.length <= 1 => conts.foldLeft(s) { case (acc, next) => next(acc) }
case Seq(h, tail @ _*) => {
val (less, greater) = tail.partition(_ < h)
step(less, { sortedLess: Seq[A] =>
/*
Can't use
step(greater, sortedGreater => (sortedLess :+ h) ++ sortedGreater)
and keep the tailrec annotation
*/
(sortedLess :+ h) ++ sort(greater)
} +: conts)
}
}
step(toSort, Nil)
}
}
在我的计算机上,上面的实现使用至少4000000个元素的随机序列,但我对它有疑问。具体来说,我想知道:
@tailrec
进行编译,但对sort(greater)
的调用似乎有点可疑。 要清楚,我已经看过this related question讨论如何使用trampolines(我知道如何使用)或你自己的显式堆栈实现尾递归快速排序,但我特别想要知道是否以及如何以不同的方式完成它。
答案 0 :(得分:1)
sort(greater)
的调用停留在延续中,它存在于堆而不是堆栈中。考虑到形状错误的足够大的问题,你可能会破坏堆,但这比吹掉堆栈要多得多。答案 1 :(得分:0)
不,您的代码不是堆栈安全的。 sort
会在更大程度上再次调用step
和step
来调用sort
,因此它不是堆栈安全的。
要做cps
,请从普通表单开始:
def sort(list: Seq[A]): Seq[A] = list match {
case s if s.length <= 1 => s
case Seq(h, tail @ _*) => {
val (less, greater) = tail.partition(_ < h)
val l = sort(less)
val g = sort(greater)
(l :+ Seq(h)) ++ g
}
}
然后将其翻译为cps,非常简单:
def sort(list: Seq[A], cont: Seq[A] => Unit): Unit = list match {
case s if s.length <= 1 => cont(s)
case Seq(h, tail @ _*) => {
val (less, greater) = tail.partition(_ < h)
sort(less, { l =>
sort(greater, { g =>
cont((l :+ Seq(h)) ++ g)
})
})
}
}
注意:
Unit
Unit
最后,将其换成普通形式:
def quicksort(list: Seq[A]): Seq[A] = {
var result
sort(list, { r => result = r })
result
}
注意:CPS转换使每个函数都进行尾调用(NOT tail-rec),因为scala不支持尾调用优化,因此您需要手动进行尾调用:
trait TCF[T] {
def result: Option[T]
def apply(): TCF[T]
}
private def tco[T](f: => TCF[T]): TCF[T] = new TCF[T] {
def result = None
def apply() = f
}
def quicksort[A: Ordering](list: Seq[A]): Seq[A] = {
case class Result(r: Seq[A]) extends Exception
Iterator.iterate(sort(list, { r: Seq[A] =>
new TCF[Seq[A]] {
def result = Some(r)
def apply() = throw new RuntimeException("unreachable")
}
}))(c => c()).dropWhile(_.result == None).next().result.get
}
private def sort[A: Ordering](list: Seq[A], cont: Seq[A] => TCF[Seq[A]]): TCF[Seq[A]] = {
val ordering = implicitly[Ordering[A]]
import ordering._
list match {
case s if s.length <= 1 => tco(cont(s))
case Seq(h, tail@_*) => {
val (less, greater) = tail.partition(_ < h)
tco(sort(less, { l: Seq[A] =>
tco(sort(greater, { g: Seq[A] =>
tco(cont((l :+ h) ++ g))
}))
}))
}
}
}
试试here。
答案 2 :(得分:0)
我决定使用JVisualVM来查看我在问题中实现的调用树,并发现由于++ step(greater)
调用它正在占用堆栈。我认为很难达到堆栈溢出的程度,因为列表每次都被分割一半,较小的一半以尾递归,堆栈安全的方式递归排序。 / p>
在考虑了这一点后,我提出了以下修订后的解决方案(试试here)
object QuickSort {
def sort[A: Ordering](toSort: Seq[A]): Seq[A] = {
val ordering = implicitly[Ordering[A]]
import ordering._
// Aliasing allows us to be tail-recursive
def step2(list: Seq[A], conts: Vector[Seq[A] => Seq[A]]): Seq[A] = step(list, conts)
@scala.annotation.tailrec
def step(list: Seq[A], conts: Vector[Seq[A] => Seq[A]]): Seq[A] = list match {
case s if s.length <= 1 => conts.foldLeft(s) { case (acc, next) => next(acc) }
case Seq(h, tail @ _*) => {
val (less, greater) = tail.partition(_ < h)
val nextConts: Vector[Seq[A] => Seq[A]] =
{ sortedLess: Seq[A] =>
sortedLess :+ h
} +: { appendedLess: Seq[A] =>
step2(greater, Vector({ sortedGreater => appendedLess ++ sortedGreater }))
} +: conts
step(less, nextConts)
}
}
step(toSort, Vector.empty)
}
}
主要区别是:
step2
step
别名来保持@tailrec
注释的快乐。 step(greater)
累加器,而不是在继续排序较少的分区中调用conts
,我们将已排序的较少分区附加到已排序的较大分区。我想你可以说这个累加器只是堆上的堆栈.. 有趣的是,这个解决方案变得非常快,击败了linked question中的Scalaz蹦床解决方案。将它与上面的半堆栈解决方案进行比较,在排序100万个元素时速度大约慢30 ns,但这是错误的。
[info] Benchmark (sortLength) Mode Cnt Score Error Units
[info] SortBenchmarks.sort 100 avgt 30 0.034 ± 0.001 ms/op
[info] SortBenchmarks.sort 10000 avgt 30 6.258 ± 0.072 ms/op
[info] SortBenchmarks.sort 1000000 avgt 30 1016.849 ± 23.572 ms/op
[info] SortBenchmarks.scalazSort 100 avgt 30 0.070 ± 0.001 ms/op
[info] SortBenchmarks.scalazSort 10000 avgt 30 10.426 ± 0.092 ms/op
[info] SortBenchmarks.scalazSort 1000000 avgt 30 1635.693 ± 68.068 ms/op