如何从递归函数中获取终止原因?

时间:2018-09-20 19:25:24

标签: algorithm scala loops functional-programming stream

假设循环一个函数以产生数字结果。如果达到最大迭代次数或满足“最佳性”条件,则停止循环。在任何一种情况下,都会输出当前回路的值。获得此结果和停止原因的有效方法是什么?

为说明起见,这是https://www.cs.kent.ac.uk/people/staff/dat/miranda/whyfp90.pdf中4.1中“平方根”示例的Scala实现。

object SquareRootAlg {
    def next(a: Double)(x: Double): Double = (x + a/x)/2
    def repeat[A](f: A=>A, a: A): Stream[A] = a #:: repeat(f, f(a))

    def loopConditional[A](stop: (A, A) => Boolean)(s: => Stream[A] ): A = s match {
          case a #:: t  if t.isEmpty => a
          case a #:: t => if (stop(a, t.head)) t.head else loopConditional(stop)(t)}  
  }

例如,找到4的平方根:

import SquareRootAlg._
val cond = (a: Double, b: Double) => (a-b).abs < 0.01
val alg = loopConditional(cond) _
val s = repeat(next(4.0), 4.0)

alg(s.take(3))  // = 2.05, "maxIters exceeded"
alg(s.take(5)) // = 2.00000009, "optimality reached"

此代码有效,但没有给出停止的原因。所以 我正在尝试编写方法

 def loopConditionalInfo[A](stop: (A, A)=> Boolean)(s: => Stream[A]):  (A, Boolean) 

在上面的第一种情况下输出(2.05, false),在第二种情况下输出(2.00000009, true)。有没有一种方法可以编写此方法而无需修改nextrepeat方法?还是其他功能方法会更好?

2 个答案:

答案 0 :(得分:4)

通常,您需要返回一个既包含停止原因又包含结果的值。您建议使用(A, Boolean)返回签名来实现这一目的。

您的代码将变为:

import scala.annotation.tailrec

object SquareRootAlg {
  def next(a: Double)(x: Double): Double = (x + a/x)/2
  def repeat[A](f: A=>A, a: A): Stream[A] = a #:: repeat(f, f(a))

  @tailrec // Checks function is truly tail recursive.
  def loopConditional[A](stop: (A, A) => Boolean)(s: => Stream[A] ): (A, Boolean) = {
    val a = s.head
    val t = s.tail
    if(t.isEmpty) (a, false)
    else if(stop(a, t.head)) (t.head, true)
    else loopConditional(stop)(t)
  }
}

答案 1 :(得分:1)

只需返回布尔值而无需修改其他任何内容:

object SquareRootAlg {
  def next(a: Double)(x: Double): Double = (x + a/x)/2
  def repeat[A](f: A => A, a: A): Stream[A] = a #:: repeat(f, f(a))

  def loopConditionalInfo[A]
    (stop: (A, A)=> Boolean)
    (s: => Stream[A])
  : (A, Boolean) = s match {
    case a #:: t if t.isEmpty => (a, false)
    case a #:: t => 
      if (stop(a, t.head)) (t.head, true) 
      else loopConditionalInfo(stop)(t)
  }
}

import SquareRootAlg._
val cond = (a: Double, b: Double) => (a-b).abs < 0.01
val alg = loopConditionalInfo(cond) _
val s = repeat(next(4.0), 4.0)

println(alg(s.take(3))) // = 2.05, "maxIters exceeded"
println(alg(s.take(5)))

打印

(2.05,false)
(2.0000000929222947,true)