由TailRec和State monad

时间:2017-12-20 13:39:02

标签: scala monads

为了简化我的问题,我将从一个学术范例开始,即ackermann函数。

我使用以下递归天真实现:

def a(m: BigInt, n: BigInt): BigInt = {
  if (m == 0) {
    n + 1
  } else if (m > 0 && n == 0) a(m - 1, 1)
  else a(m - 1, a(m, n - 1))
}

这不是最佳的,并且在堆栈溢出中快速结束。 所以我构建了一个新的实现,它使用标准scala库中的TailRec,并给出了:

import scala.util.control.TailCalls._

private[this] def a_impl(m: BigInt, n: BigInt): TailRec[BigInt] = {
  if (m == 0) {
    done(n + 1)
  } else if (m > 0 && n == 0) tailcall(a_impl(m - 1, 1))
  else
    for {
      x <- tailcall(a_impl(m, n - 1))
      y <- tailcall(a_impl(m - 1, x))
    } yield y

}

def a(m: BigInt, n: BigInt): BigInt = {
  a_impl(m, n).result
}

它有效,但速度很慢。 所以我构建了一个使用State monad的新实现,但我又失去了终端递归。

type Memo = Map[(BigInt, BigInt), BigInt]

private[this] def a_impl(m: BigInt, n: BigInt): State[Memo, BigInt] = {
  if (m == 0) {
    State.init(n + 1)
  } else {
    for {
      memoed <- State.gets { memo: Memo => memo get (m, n) }
      res <- memoed match {
        case Some(ack) => State.init[Memo, BigInt](ack)
        case None =>
          if (m > 0 && n == 0) for {
            a <- a_impl(m - 1, 1)
            _ <- State.update { memo: Memo => memo + ((m, n) -> a) }
          } yield a
          else for {
            a <- a_impl(m, n - 1)
            b <- a_impl(m - 1, a)
            _ <- State.update { memo: Memo => memo + ((m, n) -> b) }
          } yield b
      }
    } yield res
  }
}

def a(m: BigInt, n: BigInt): BigInt = {
  a_impl(m, n) eval (Map())
}

所以我的问题是,如何同时使用State和TailRec?

我已经看过Monad Transformer的概念,但我不知道如何在我的例子中使用它。 我甚至不知道使用哪种类型,我可以选择它和它之间:

type TailRecWithState = TailRec[State[Memo, BigInt]] 
// or  
type StateWithTailRec = State[Memo, TailRec[BigInt]]

你能帮助我并指出我在这个例子上的正确方向(我会根据我的实际案例进行管理)吗?

1 个答案:

答案 0 :(得分:3)

至少在cat中我知道,State[S, A]StateT[Eval, S, A]的类型别名,其中EvalTailRec类似,完全符合您的需要 - 堆栈安全延迟执行。这对我来说很好用:

import cats._, cats.data._, cats.implicits._

type Memo = Map[(BigInt, BigInt), BigInt]

private[this] def a_impl(m: BigInt, n: BigInt): State[Memo, BigInt] = {
  if (m == 0) {
    State.pure(n + 1)
  } else {
    for {
      memoed <- State.inspect[Memo, Option[BigInt]](s => s.get((m, n)))
      res <- memoed match {
        case Some(x) => State.pure[Memo, BigInt](x)
        case None => {
          if (n == 0) for {
            a <- a_impl(m - 1, 1)
            _ <- State.modify[Memo](s => s + ((m, n) -> a))
          } yield a
          else for {
            a <- a_impl(m, n - 1)
            b <- a_impl(m - 1, a)
            _ <- State.modify[Memo](s => s + ((m, n) -> b))
          } yield b
        }
      }
    } yield res
  }
}

def a(m: BigInt, n: BigInt): BigInt = {
  a_impl(m, n).runA(Map()).value
}

我的猜测是scalaz也可能有类似的StateTEval,但我对这个库并不熟悉。