Pascal的Trical Scala:使用尾部递归方法计算Pascal的三角形元素

时间:2019-10-16 11:35:02

标签: algorithm scala recursion tail-recursion

在Pascal三角形中,三角形边缘的数字均为1,三角形内部的每个数字均为其上方两个数字的和。帕斯卡三角形的示例如下所示。

    1
   1 1
  1 2 1
 1 3 3 1
1 4 6 4 1

我编写了一个程序,使用以下技术计算帕斯卡三角形的元素。

/**
* Can I make it tail recursive???
*
* @param c column
* @param r row
* @return 
*/
def pascalTriangle(c: Int, r: Int): Int = {
  if (c == 0 || (c == r)) 1
  else
    pascalTriangle(c-1, r-1) + pascalTriangle(c, r - 1)
}

例如,如果

i/p: pascalTriangle(0,2)  
o/p: 1.

i/p: pascalTriangle(1,3)  
o/p: 3.

以上程序正确无误,并提供了预期的正确输出。我的问题是,是否可以编写上述算法的尾部递归版本?怎么样?

2 个答案:

答案 0 :(得分:1)

尝试

service apache2 reload
apache2 graceful

def pascalTriangle(c: Int, r: Int): Int = {
  @tailrec
  def loop(c0: Int, r0: Int, pred: Array[Int], cur: Array[Int]): Int = {
    cur(c0) = (if (c0 > 0) pred(c0 - 1) else 0) + (if (c0 < r0) pred(c0) else 0)

    if ((c0 == c) && (r0 == r)) cur(c0)
    else if (c0 < r0) loop(c0 + 1, r0, pred, cur)
    else loop(0, r0 + 1, cur, new Array(_length = r0 + 2))
  }

  if ((c == 0) && (r == 0)) 1
  else loop(0, 1, Array(1), Array(0, 0))
}

import scala.util.control.TailCalls._

def pascalTriangle(c: Int, r: Int): Int = {
  def hlp(c: Int, r: Int): TailRec[Int] =
    if (c == 0 || (c == r)) done(1)
    else for {
      x <- tailcall(hlp(c - 1, r - 1))
      y <- tailcall(hlp(c, r - 1))
    } yield (x + y)

  hlp(c, r).result
}

http://eed3si9n.com/herding-cats/stackless-scala-with-free-monads.html

答案 1 :(得分:0)

对@DmytroMitin 的第一个解决方案的一些优化:

  1. if ((c == 0) && (r == 0)) 1 替换为 if (c == 0 || c == r) 1
  2. 利用三角形的对称性,如果 c 大于 r 的一半,则使用 c 的反射值。

通过这些优化,对 loop 绘制 30 行三角形的调用次数从 122,760 次减少到 112,375 次(使用 #1)减少到 110,240 次(使用 #1 和 #2)调用(没有记忆化) ).

  def pascalTail(c: Int, r: Int): Int = {
    val cOpt = if (c > r/2) r - c else c
    def loop(col: Int, row: Int, previous: Array[Int], current: Array[Int]): Int = {
      current(col) = (if (col > 0) previous(col - 1) else 0) + (if (col < row) previous(col) else 0)

      if ((col == cOpt) && (row == r)) current(col)
      else if (col < row) loop(col + 1, row, previous, current)
      else loop(0, row + 1, current, new Array(_length = row + 2))
    }

    if (c == 0 || c == r) 1
    else loop(0, 1, Array(1), new Array(_length = 2))
  }