Scala recursion vs. loops: performance and runtime considerations

Time: 2013-02-28 14:56:25

Tags: performance scala stack-overflow tail-recursion

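I wrote a naive test bed to measure the performance of three factorial implementations: loop-based, non-tail-recursive, and tail-recursive.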

  

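To my surprise, the worst performer was the loop-based one («while» was expected to be more efficient, so I provided both a «for» and a «while» version), which cost almost twice as much as the tail-recursive alternative.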

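ANSWER: after fixing the loop implementation to avoid the *= operator on BigInt, which performed worst because of its internals, the «loops» became the fastest, as expected.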

  

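Another «voodoo» behavior I ran into is the StackOverflow exception, which was not thrown systematically for the same input in the non-tail-recursive implementation. I can circumvent the StackOverflow by calling the function progressively with larger and larger values... I feel crazy :) ANSWER: the JVM needs to converge during startup; after that the behavior is coherent and systematic.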

Here is the code:

import scala.annotation.tailrec // needed for the @tailrec annotations below

final object Factorial {
  type Out = BigInt

  def calculateByRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    n match {
      case _ if n == 1 => return 1
      case _ => return n * calculateByRecursion(n-1)
    }
  }

  def calculateByForLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var accumulator: Out = 1
    for (i <- 1 to n)
      accumulator = i * accumulator
    accumulator
  }

  def calculateByWhileLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var accumulator: Out = 1
    var i = 1
    while (i <= n) {
      accumulator = i * accumulator
      i += 1
    }
    accumulator
  }

  def calculateByTailRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    @tailrec def fac(n: Int, acc: Out): Out = n match {
      case _ if n == 1 => acc
      case _ => fac(n-1, n * acc)
    }

    fac(n, 1)
  }

  def calculateByTailRecursionUpward(n: Int): Out = {
    require(n>0, "n must be positive")

    @tailrec def fac(i: Int, acc: Out): Out = n match {
      case _ if i == n => n * acc
      case _ => fac(i+1, i * acc)
    }

    fac(1, 1)
  }

  def comparePerformance(n: Int) {
    def showOutput[A](msg: String, data: (Long, A), showOutput:Boolean = false) =
      showOutput match {
        case true => printf("%s returned %s in %d ms\n", msg, data._2.toString, data._1)
        case false => printf("%s in %d ms\n", msg, data._1)
    }
    def measure[A](f:()=>A): (Long, A) = {
      val start = System.currentTimeMillis
      val o = f()
      (System.currentTimeMillis - start, o)
    }
    showOutput ("By for loop", measure(()=>calculateByForLoop(n)))
    showOutput ("By while loop", measure(()=>calculateByWhileLoop(n)))
    showOutput ("By non-tail recursion", measure(()=>calculateByRecursion(n)))
    showOutput ("By tail recursion", measure(()=>calculateByTailRecursion(n)))
    showOutput ("By tail recursion upward", measure(()=>calculateByTailRecursionUpward(n)))
  }
}

Here is some output from the sbt console (before the «while» implementation):

scala> example.Factorial.comparePerformance(10000)
By loop in 3 ns
By non-tail recursion in >>>>> StackOverflow!!!!!… see later!!!
........

scala> example.Factorial.comparePerformance(1000)
By loop in 3 ms
By non-tail recursion in 1 ms
By tail recursion in 4 ms

scala> example.Factorial.comparePerformance(5000)
By loop in 105 ms
By non-tail recursion in 27 ms
By tail recursion in 34 ms

scala> example.Factorial.comparePerformance(10000)
By loop in 236 ms
By non-tail recursion in 106 ms     >>>> Now works!!!
By tail recursion in 127 ms

scala> example.Factorial.comparePerformance(20000)
By loop in 977 ms
By non-tail recursion in 495 ms
By tail recursion in 564 ms

scala> example.Factorial.comparePerformance(30000)
By loop in 2285 ms
By non-tail recursion in 1183 ms
By tail recursion in 1281 ms

Here is some output from the sbt console (after the «while» implementation):

scala> example.Factorial.comparePerformance(10000)
By for loop in 252 ms
By while loop in 246 ms
By non-tail recursion in 130 ms
By tail recursion in 136 ns

scala> example.Factorial.comparePerformance(20000)
By for loop in 984 ms
By while loop in 1091 ms
By non-tail recursion in 508 ms
By tail recursion in 560 ms

Here is some output from the sbt console (after the «upward» tail-recursion implementation): the world is sane again.

scala> example.Factorial.comparePerformance(10000)
By for loop in 259 ms
By while loop in 229 ms
By non-tail recursion in 114 ms
By tail recursion in 119 ms
By tail recursion upward in 105 ms

scala> example.Factorial.comparePerformance(20000)
By for loop in 1053 ms
By while loop in 957 ms
By non-tail recursion in 513 ms
By tail recursion in 565 ms
By tail recursion upward in 470 ms

Here is some output from the sbt console after fixing the BigInt multiplication in the «loops»: the world is entirely sane.

scala> example.Factorial.comparePerformance(20000)
By for loop in 498 ms
By while loop in 502 ms
By non-tail recursion in 521 ms
By tail recursion in 611 ms
By tail recursion upward in 503 ms

The BigInt overhead and my stupid implementation masked the expected behavior.

Thank you all.

PS: in the end I should retitle this post "A lesson on BigInts".

3 answers:

Answer 0 (score: 12)

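For loops aren't actually quite loops; they are for comprehensions over a range. If you really want a loop, you need to use while. (Actually, I think the BigInt multiplication here is heavy enough that it shouldn't matter. But you would notice the difference if you were multiplying Ints.)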
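To make the point about for comprehensions concrete, here is a minimal sketch of the rewrite the compiler applies to a `for` over a range: it becomes a `foreach` call that takes a closure, while `while` stays a plain loop. The object and method names are made up for illustration and are not part of the question's code.

```scala
object ForDesugaring {
  // A `for` over a Range ...
  def byFor(n: Int): BigInt = {
    var acc: BigInt = 1
    for (i <- 1 to n) acc = acc * i
    acc
  }

  // ... is rewritten by the compiler into a foreach call taking a closure.
  def byForeach(n: Int): BigInt = {
    var acc: BigInt = 1
    (1 to n).foreach(i => acc = acc * i)
    acc
  }

  // A while loop needs no Range object and no closure.
  def byWhile(n: Int): BigInt = {
    var acc: BigInt = 1
    var i = 1
    while (i <= n) { acc = acc * i; i += 1 }
    acc
  }
}
```

All three methods compute the same value; they differ only in how the iteration itself is expressed, which is why the difference only becomes visible when the loop body is cheap.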

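Also, you have confused yourself by using BigInt. The bigger your BigInt is, the slower your multiplication. Your non-tail recursion counts up while your tail recursion counts down, which means that the latter has more large numbers to multiply.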
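One rough way to see the effect of the counting direction is to add up how large the accumulator is before each multiplication. The sketch below uses the accumulator's bit length as a proxy for multiplication cost; that proxy, and the names `OperandSizes`, `totalBits`, `upward` and `downward`, are assumptions made for illustration only.

```scala
object OperandSizes {
  // Sum of the accumulator's bit length before each multiplication,
  // used here as a rough proxy for how much work the multiplications do.
  def totalBits(factors: Seq[Int]): Long = {
    var acc = BigInt(1)
    var bits = 0L
    for (f <- factors) {
      bits += acc.bitLength
      acc = acc * f
    }
    bits
  }

  // Multiplies 1 * 2 * 3 * ... * n: the accumulator stays small for a long time.
  def upward(n: Int): Long = totalBits(1 to n)

  // Multiplies n * (n-1) * ... * 1: the accumulator is large from the first steps.
  def downward(n: Int): Long = totalBits(n to 1 by -1)
}
```

After the same number of steps, the downward product holds the largest factors and the upward product the smallest, so `downward(n)` is never smaller than `upward(n)`.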

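If you fix these two issues you will find that sanity is restored: the while loop and tail recursion run at the same speed, while regular recursion and for are both slower. (Regular recursion may not be slower if JVM optimization makes it equivalent.)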

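(Also, the stack overflow probably went away because the JVM started inlining, and may either have made the call tail-recursive itself or unrolled the loop far enough that it no longer overflows.)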
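To observe how the reachable recursion depth behaves on a given JVM, one can count frames until the stack overflows. This is only a sketch: `StackDepthProbe`, `dive` and `maxDepth` are hypothetical names, and the numbers it reports depend on the JVM, its stack-size settings and the frame layout of the compiled code.

```scala
object StackDepthProbe {
  private var depth = 0

  // Deliberately not tail-recursive: the addition happens after the call returns.
  private def dive(remaining: Int): Int =
    if (remaining == 0) 0
    else {
      depth += 1
      1 + dive(remaining - 1)
    }

  // Returns how many frames were entered before the stack overflowed.
  def maxDepth(): Int = {
    depth = 0
    try { dive(Int.MaxValue); () }
    catch { case _: StackOverflowError => () }
    depth
  }
}
```

Calling `StackDepthProbe.maxDepth()` several times in the same session, for example `(1 to 5).map(_ => StackDepthProbe.maxDepth())`, shows whether the limit shifts between calls on a given setup.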

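Finally, you are getting poor results with for and while because you are multiplying with the small number on the right rather than on the left. It turns out that Java's BigInt multiplies faster with the smaller number on the left.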
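The operand-order claim can be checked on your own JVM with a small timing harness that warms up before measuring. The following is a sketch under stated assumptions: `OperandOrder`, `smallOnLeft`, `smallOnRight` and `bestOf` are hypothetical names, and taking the best of a few timed runs is just one simple way to reduce noise, not a substitute for a proper benchmarking tool.

```scala
object OperandOrder {
  // Small factor on the left-hand side of the multiplication.
  def smallOnLeft(n: Int): BigInt = {
    var acc = BigInt(1)
    var i = 1
    while (i <= n) { acc = BigInt(i) * acc; i += 1 }
    acc
  }

  // Small factor on the right-hand side of the multiplication.
  def smallOnRight(n: Int): BigInt = {
    var acc = BigInt(1)
    var i = 1
    while (i <= n) { acc = acc * BigInt(i); i += 1 }
    acc
  }

  // Runs `body` a few times untimed so the JIT can compile it,
  // then returns the best of `runs` timed executions in nanoseconds.
  def bestOf[A](warmups: Int, runs: Int)(body: => A): Long = {
    var w = 0
    while (w < warmups) { body; w += 1 }
    var best = Long.MaxValue
    var r = 0
    while (r < runs) {
      val start = System.nanoTime()
      body
      best = math.min(best, System.nanoTime() - start)
      r += 1
    }
    best
  }
}
```

For example, compare `OperandOrder.bestOf(5, 5)(OperandOrder.smallOnLeft(20000))` with `OperandOrder.bestOf(5, 5)(OperandOrder.smallOnRight(20000))`. Both variants return the same factorial, so any difference in the timings comes from the operand order alone.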

Answer 1 (score: 1)

Scala static methods for factorial(n) (coded with Scala 2.12.x, Java 8):

object Factorial {

  /*
   * For large N, it throws a stack overflow
   */
  def recursive(n:BigInt): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      1
    } else {
      n * recursive(n - 1)
    }
  }

  /*
   * A tail recursive method is compiled to avoid stack overflow
   */
  @scala.annotation.tailrec
  def recursiveTail(n:BigInt, acc:BigInt = 1): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      acc
    } else {
      recursiveTail(n - 1, n * acc)
    }
  }

  /*
   * A while loop
   */
  def loop(n:BigInt): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      1
    } else {
      var acc: BigInt = 1 // BigInt accumulator: a plain Int would silently overflow for n > 12
      var idx = 1
      while(idx <= n) {
        acc = idx * acc
        idx += 1
      }
      acc
    }
  }

}

Specs:

class FactorialSpecs extends SpecHelper {

  private val smallInt = 10
  private val largeInt = 10000

  describe("Factorial.recursive") {
    it("return 1 for 0") {
      assert(Factorial.recursive(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.recursive(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.recursive(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.recursive(smallInt) == 3628800)
    }
    it("throws StackOverflow for large inputs") {
      intercept[java.lang.StackOverflowError] {
        Factorial.recursive(Int.MaxValue)
      }
    }
  }

  describe("Factorial.recursiveTail") {
    it("return 1 for 0") {
      assert(Factorial.recursiveTail(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.recursiveTail(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.recursiveTail(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.recursiveTail(smallInt) == 3628800)
    }
    it("returns a result, for large inputs") {
      assert(Factorial.recursiveTail(largeInt).isInstanceOf[BigInt])
    }
  }

  describe("Factorial.loop") {
    it("return 1 for 0") {
      assert(Factorial.loop(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.loop(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.loop(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.loop(smallInt) == 3628800)
    }
    it("returns a result, for large inputs") {
      assert(Factorial.loop(largeInt).isInstanceOf[BigInt])
    }
  }
}

Benchmark:

import org.scalameter.api._

class BenchmarkFactorials extends Bench.OfflineReport {

  val gen: Gen[Int] = Gen.range("N")(1, 1000, 100) // scalastyle:ignore

  performance of "Factorial" in {
    measure method "loop" in {
      using(gen) in {
        n => Factorial.loop(n)
      }
    }
    measure method "recursive" in {
      using(gen) in {
        n => Factorial.recursive(n)
      }
    }
    measure method "recursiveTail" in {
      using(gen) in {
        n => Factorial.recursiveTail(n)
      }
    }
  }

}

Benchmark results (the loop is faster):

[info] Test group: Factorial.loop
[info] - Factorial.loop.Test-9 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.01 ms, ci = <0.00 ms, 0.02 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.01 ms, ci = <0.01 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.02 ms, ci = <0.02 ms, 0.02 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.03 ms, ci = <0.02 ms, 0.03 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.03 ms, ci = <0.03 ms, 0.04 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.04 ms, ci = <0.03 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.04 ms, ci = <0.04 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.05 ms, ci = <0.05 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.06 ms, ci = <0.05 ms, 0.06 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.06 ms, ci = <0.05 ms, 0.07 ms>, significance = 1.0E-10)

[info] Test group: Factorial.recursive
[info] - Factorial.recursive.Test-10 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.00 ms, ci = <0.00 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.05 ms, ci = <0.01 ms, 0.09 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.03 ms, ci = <0.02 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.07 ms, ci = <0.00 ms, 0.13 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.09 ms, ci = <0.01 ms, 0.18 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.10 ms, ci = <0.03 ms, 0.17 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.11 ms, ci = <0.08 ms, 0.15 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.13 ms, ci = <0.11 ms, 0.14 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.16 ms, ci = <0.13 ms, 0.19 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.21 ms, ci = <0.15 ms, 0.27 ms>, significance = 1.0E-10)

[info] Test group: Factorial.recursiveTail
[info] - Factorial.recursiveTail.Test-11 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.00 ms, ci = <0.00 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.04 ms, ci = <0.03 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.12 ms, ci = <0.05 ms, 0.20 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.16 ms, ci = <-0.03 ms, 0.34 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.12 ms, ci = <0.09 ms, 0.16 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.17 ms, ci = <0.15 ms, 0.19 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.23 ms, ci = <0.19 ms, 0.26 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.25 ms, ci = <0.18 ms, 0.32 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.28 ms, ci = <0.21 ms, 0.36 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.32 ms, ci = <0.17 ms, 0.46 ms>, significance = 1.0E-10)

Answer 2 (score: 1)

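I know everyone has already answered the question, but I thought I could add this one optimization: if you convert the pattern matching into a simple if statement, the tail recursion can get faster.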

final object Factorial {
  type Out = BigInt

  def calculateByRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    n match {
      case _ if n == 1 => return 1
      case _ => return n * calculateByRecursion(n-1)
    }
  }

  def calculateByForLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var accumulator: Out = 1
    for (i <- 1 to n)
      accumulator = i * accumulator
    accumulator
  }

  def calculateByWhileLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var acc: Out = 1
    var i = 1
    while (i <= n) {
      acc = i * acc
      i += 1
    }
    acc
  }

  def calculateByTailRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    @annotation.tailrec
    def fac(n: Int, acc: Out): Out = if (n==1) acc else fac(n-1, n*acc)

    fac(n, 1)
  }

  def calculateByTailRecursionUpward(n: Int): Out = {
    require(n>0, "n must be positive")

    @annotation.tailrec
    def fac(i: Int, acc: Out): Out = if (i == n) n*acc else fac(i+1, i*acc)

    fac(1, 1)
  }

  def attempt(f: ()=>Unit): Boolean = {
    try {
        f()
        true
    } catch {
        case _: Throwable =>
            println(" <<<<< Failed...")
            false
    }
  }

  def comparePerformance(n: Int) {
    def showOutput[A](msg: String, data: (Long, A), showOutput:Boolean = true) =
      showOutput match {
        case true =>
            val res = data._2.toString
            val pref = res.substring(0,5)
            val midd = res.substring((res.length-5)/ 2, (res.length-5)/ 2 + 10)
            val suff = res.substring(res.length-5)
            printf("%s returned %s in %d ms\n", msg, s"$pref...$midd...$suff" , data._1)
        case false => 
            printf("%s in %d ms\n", msg, data._1)
    }
    def measure[A](f:()=>A): (Long, A) = {
      val start = System.currentTimeMillis
      val o = f()
      (System.currentTimeMillis - start, o)
    }
    attempt(() => showOutput ("By for loop", measure(()=>calculateByForLoop(n))))
    attempt(() => showOutput ("By while loop", measure(()=>calculateByWhileLoop(n))))
    attempt(() => showOutput ("By non-tail recursion", measure(()=>calculateByRecursion(n))))
    attempt(() => showOutput ("By tail recursion", measure(()=>calculateByTailRecursion(n))))
    attempt(() => showOutput ("By tail recursion upward", measure(()=>calculateByTailRecursionUpward(n))))
  }
}

My results:

scala> Factorial.comparePerformance(20000)
By for loop returned 18192...5708616582...00000 in 179 ms
By while loop returned 18192...5708616582...00000 in 159 ms
By non-tail recursion <<<<< Failed...
By tail recursion returned 18192...5708616582...00000 in 169 ms
By tail recursion upward returned 18192...5708616582...00000 in 174 ms
By for loop returned 18192...5708616582...00000 in 212 ms
By while loop returned 18192...5708616582...00000 in 156 ms
By non-tail recursion returned 18192...5708616582...00000 in 155 ms
By tail recursion returned 18192...5708616582...00000 in 166 ms
By tail recursion upward returned 18192...5708616582...00000 in 137 ms
scala> Factorial.comparePerformance(200000)
By for loop returned 14202...0169293868...00000 in 17467 ms
By while loop returned 14202...0169293868...00000 in 17303 ms
By non-tail recursion <<<<< Failed...
By tail recursion returned 14202...0169293868...00000 in 18477 ms
By tail recursion upward returned 14202...0169293868...00000 in 17188 ms