假设循环一个函数以产生数字结果。如果达到最大迭代次数或满足“最佳性”条件,则停止循环。在任何一种情况下,都会输出当前回路的值。获得此结果和停止原因的有效方法是什么?
为说明起见,这是https://www.cs.kent.ac.uk/people/staff/dat/miranda/whyfp90.pdf中4.1中“平方根”示例的Scala实现。
object SquareRootAlg {
def next(a: Double)(x: Double): Double = (x + a/x)/2
def repeat[A](f: A=>A, a: A): Stream[A] = a #:: repeat(f, f(a))
def loopConditional[A](stop: (A, A) => Boolean)(s: => Stream[A] ): A = s match {
case a #:: t if t.isEmpty => a
case a #:: t => if (stop(a, t.head)) t.head else loopConditional(stop)(t)}
}
例如,找到4的平方根:
import SquareRootAlg._
val cond = (a: Double, b: Double) => (a-b).abs < 0.01
val alg = loopConditional(cond) _
val s = repeat(next(4.0), 4.0)
alg(s.take(3)) // = 2.05, "maxIters exceeded"
alg(s.take(5)) // = 2.00000009, "optimality reached"
此代码有效,但没有给出停止的原因。所以 我正在尝试编写方法
def loopConditionalInfo[A](stop: (A, A)=> Boolean)(s: => Stream[A]): (A, Boolean)
在上面的第一种情况下输出(2.05, false)
,在第二种情况下输出(2.00000009, true)
。有没有一种方法可以编写此方法而无需修改next
和repeat
方法?还是其他功能方法会更好?
答案 0 :(得分:4)
通常,您需要返回一个既包含停止原因又包含结果的值。您建议使用(A, Boolean)
返回签名来实现这一目的。
您的代码将变为:
import scala.annotation.tailrec
object SquareRootAlg {
def next(a: Double)(x: Double): Double = (x + a/x)/2
def repeat[A](f: A=>A, a: A): Stream[A] = a #:: repeat(f, f(a))
@tailrec // Checks function is truly tail recursive.
def loopConditional[A](stop: (A, A) => Boolean)(s: => Stream[A] ): (A, Boolean) = {
val a = s.head
val t = s.tail
if(t.isEmpty) (a, false)
else if(stop(a, t.head)) (t.head, true)
else loopConditional(stop)(t)
}
}
答案 1 :(得分:1)
只需返回布尔值而无需修改其他任何内容:
object SquareRootAlg {
def next(a: Double)(x: Double): Double = (x + a/x)/2
def repeat[A](f: A => A, a: A): Stream[A] = a #:: repeat(f, f(a))
def loopConditionalInfo[A]
(stop: (A, A)=> Boolean)
(s: => Stream[A])
: (A, Boolean) = s match {
case a #:: t if t.isEmpty => (a, false)
case a #:: t =>
if (stop(a, t.head)) (t.head, true)
else loopConditionalInfo(stop)(t)
}
}
import SquareRootAlg._
val cond = (a: Double, b: Double) => (a-b).abs < 0.01
val alg = loopConditionalInfo(cond) _
val s = repeat(next(4.0), 4.0)
println(alg(s.take(3))) // = 2.05, "maxIters exceeded"
println(alg(s.take(5)))
打印
(2.05,false)
(2.0000000929222947,true)