重构一个方法来修复尾递归调用不在尾部位置

时间:2015-03-29 18:36:32

标签: scala tail-recursion

考虑以下递归幂方法乘法:

import scala.annotation.tailrec
@tailrec def mult(x: Double, n:Int) : Double = {
      n match {
    case 0 => 1
    case 1 => x
    case _ if ((n & 0x01) != 0) =>  x * mult(x*x, (n-1)/2)
    case _ =>  mult(x*x, n/2)
    }
}

编译错误是:

<console>:28: error: could not optimize @tailrec annotated method mult: 
it contains a recursive call not in tail position
             y *  mult(x*x,(n-2)/2)
               ^

所以..鉴于递归调用最后一个条目 - 我认为产品y *(尾递归子句)存在问题?如何正确地构建这个?

更新

以下是已接受答案的修改版本 - 其中我很懒,只是在被调用的方法中放了第三个累加器。

@tailrec def mult(x: Double, n:Int, accum: Double = 1.0) : Double = {
        n match {
      case 0 => accum
      case 1 => accum * x
      case _ if ((n & 0x01) != 0) =>  mult(x*x, (n-1)/2, x * accum)
      case _ =>  mult(x*x, n/2, accum)
      }
  }

mult: (x: Double, n: Int, accum: Double)Double

尝试一下:

scala> mult(2, 7)
res0: Double = 128.0

scala> mult(2, 8)
res1: Double = 256.0

3 个答案:

答案 0 :(得分:4)

有两种方法可以解决这类问题。第一种是在调用中移动乘法,可能是通过添加辅助方法:

import scala.annotation.tailrec

def mult(x: Double, n: Int): Double = {
  @tailrec
  def go(x: Double, n: Int, mult: Double): Double = n match {
    case 0 => mult
    case 1 => mult * x
    case _ if (n & 0x01) != 0 => go(x * x, (n - 1) / 2, x * mult)
    case _ => go(x * x, n / 2, mult)
  }
  go(x, n, 1)
}

另一个字面上并不是你问题的答案,但在某些情况下它可能是一种更方便的方法。它被称为&#34; trampolining&#34;:

import scala.util.control.TailCalls._

def mult(x: Double, n: Int): Double = {
  def go(x: Double, n: Int): TailRec[Double] = n match {
    case 0 => done(1)
    case 1 => done(x)
    case _ if (n & 0x01) != 0 => tailcall(go(x * x, (n - 1) / 2).map(_ * x))
    case _ => tailcall(go(x * x, n / 2))
  }
  go(x, n).result
}

这并不要求您重新构建方法,并且保证不会破坏堆栈,但它会引入一些额外的开销。

答案 1 :(得分:1)

尾递归调用是那些最后一个语句只是函数调用本身的调用。 你的代码的最后陈述应仅为mult(x*x,(n-2)/2)

你可以试试这个。

import scala.annotation.tailrec
  @tailrec
  def mult(x: Double, n:Int,res:Double=1) : Double = {
    n match {
      case 0 => res
      case _ => mult(x,n-1,res *x)
    }
  }

答案 2 :(得分:0)

您的函数mult不是尾递归的,因为在函数体中您想要对递归调用的结果执行某些操作,即您希望将其与y相乘。

要使此尾递归,您应该构造函数mult,以便它可以将值y作为参数来移除递归调用后的乘法。以下是使用阶乘的简单示例:http://c2.com/cgi/wiki?TailRecursion