免费〜> Trampoline:递归程序崩溃与OutOfMemoryError

时间:2016-12-13 14:01:09

标签: scala out-of-memory scalaz tail-recursion free-monad

假设我尝试使用一个操作来实现一个非常简单的特定于域的语言:

printLine(line)

然后我想编写一个以整数n作为输入的程序,如果n可被10k整除则输出一些内容,然后用n + 1调用自身,直到{{1达到某个最大值n

省略由for-comprehensions引起的所有语法噪音,我想要的是:

N

基本上,它会是一种" fizzbuzz"。

以下是使用Scalaz 7.3.0-M7中的免费monad实现此操作的一些尝试:

@annotation.tailrec def p(n: Int): Unit = {
  if (n % 10000 == 0) printLine("line")
  if (n > N) () else p(n + 1)
}

不幸的是,简单的翻译(import scalaz._ object Demo1 { // define operations of a little domain specific language sealed trait Lang[X] case class PrintLine(line: String) extends Lang[Unit] // define the domain specific language as the free monad of operations type Prog[X] = Free[Lang, X] import Free.{liftF, pure} // lift operations into the free monad def printLine(l: String): Prog[Unit] = liftF(PrintLine(l)) def ret: Prog[Unit] = Free.pure(()) // write a program that is just a loop that prints current index // after every few iteration steps val mod = 100000 val N = 1000000 // straightforward syntax: deadly slow, exits with OutOfMemoryError def p0(i: Int): Prog[Unit] = for { _ <- (if (i % mod == 0) printLine("i = " + i) else ret) _ <- (if (i > N) ret else p0(i + 1)) } yield () // Same as above, but written out without `for` def p1(i: Int): Prog[Unit] = (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{ ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () } } // Same as above, with `map` attached to recursive call def p2(i: Int): Prog[Unit] = (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{ ignore1 => (if (i > N) ret else p2(i + 1).map{ ignore2 => () }) } // Same as above, but without the `map`; performs ok. def p3(i: Int): Prog[Unit] = { (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{ ignore1 => if (i > N) ret else p3(i + 1) } } // Variation of the above; Ok. def p4(i: Int): Prog[Unit] = (for { _ <- (if (i % mod == 0) printLine("i = " + i) else ret) } yield ()).flatMap{ ignored2 => if (i > N) ret else p4(i + 1) } // try to use the variable returned by the last generator after yield, // hope that the final `map` is optimized away (it's not optimized away...) def p5(i: Int): Prog[Unit] = for { _ <- (if (i % mod == 0) printLine("i = " + i) else ret) stopHere <- (if (i > N) ret else p5(i + 1)) } yield stopHere // define an interpreter that translates the programs into Trampoline import scalaz.Trampoline type Exec[X] = Free.Trampoline[X] val interpreter = new (Lang ~> Exec) { def apply[A](cmd: Lang[A]): Exec[A] = cmd match { case PrintLine(l) => Trampoline.delay(println(l)) } } // try it out def main(args: Array[String]): Unit = { println("\n p0") p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError println("\n p1") p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError println("\n p2") p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError println("\n p3") p3(0).foldMap(interpreter).run // ok println("\n p4") p4(0).foldMap(interpreter).run // ok println("\n p5") p5(0).foldMap(interpreter).run // OutOfMemory } } )似乎与某种O(N ^ 2)开销一起运行,并且与OutOfMemoryError崩溃。问题似乎是p0 - 理解在递归调用for之后追加map{x => ()},这迫使p0 monad用提醒来填充整个内存。 #34;完成&#39; p0&#39;然后什么都不做&#34;。 如果我手动&#34;展开&#34; Free理解,并明确地写出最后的for(如flatMapp3),然后问题消失,一切顺利进行。但是,这是一个非常脆弱的解决方法:如果我们只是向p4添加map(id),程序的行为会发生显着变化,而map(id)在代码中甚至不可见,因为它由for - 理解自动生成。

在这篇较旧的帖子中:https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/ 已反复建议将递归调用包装到suspend中。以下是Applicative实例和suspend的尝试:

import scalaz._

// Essentially same as in `Demo1`, but this time with 
// an `Applicative` and an explicit `Suspend` in the 
// `for`-comprehension
object Demo2 {

  sealed trait Lang[H]

  case class Const[H](h: H) extends Lang[H]
  case class PrintLine[H](line: String) extends Lang[H]

  implicit object Lang extends Applicative[Lang] {
    def point[A](a: => A): Lang[A] = Const(a)
    def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
      case Const(x) => {
        f match {
          case Const(ab) => Const(ab(x))
          case _ => throw new Error
        }
      }
      case PrintLine(l) => PrintLine(l)
    }
  }

  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  val mod = 100000
  val N = 2000000

  // try to suspend the entire second generator
  def p7(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- Free.suspend(if (i > N) ret else p7(i + 1))
  } yield ()

  // try to suspend the recursive call
  def p8(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- if (i > N) ret else Free.suspend(p8(i + 1))
  } yield ()

  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]

  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case Const(x) => Trampoline.done(x)
      case PrintLine(l) => 
        (Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
    }
  }

  def main(args: Array[String]): Unit = {
    p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    p8(0).foldMap(interpreter).run // same...
  }
}

插入suspend并没有什么帮助:它仍然很慢,并且与OutOfMemoryError一起崩溃。

我应该以某种方式使用suspend吗?

也许有一些纯粹的句法补救措施可以在不产生map的情况下使用for-comprehensions吗?

如果有人能指出我在这里做错了什么,以及如何修复它,我真的很感激。

1 个答案:

答案 0 :(得分:3)

Scala编译器添加的多余map将递归从尾部位置移动到非尾部位置。免费monad仍然使这个堆栈安全,但空间复杂性变为 O(N)而不是 O(1)。 (具体来说,它仍然不是 O(N 2 。)

是否有可能scalac优化map距离会产生一个单独的问题(我不知道答案)。

我将尝试说明在解释p1p3时发生的情况。 (我将忽略Trampoline的翻译,这是多余的(见下文)。)

p3(即没有额外map

让我用以下简写:

def cont(i: Int): Unit => Prg[Unit] =
  ignore1 => if (i > N) ret else p3(i + 1)

现在p3(0)解释如下

p3(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p3(1)
ret flatMap cont(1)
cont(1)
p3(2)
ret flatMap cont(2)
cont(2)

等等......你看到任何一点所需的内存量都不会超过某个常数上限。

p1(即附加map

我将使用以下缩写:

def cont(i: Int): Unit => Prg[Unit] =
  ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }

def cpu: Unit => Prg[Unit] = // constant pure unit
  ignore => Free.pure(())

现在p1(0)解释如下:

p1(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p1(1) map { ignore2 => () }
// Free.map is implemented via flatMap
p1(1) flatMap cpu
(ret flatMap cont(1)) flatMap cpu
cont(1) flatMap cpu
(p1(2) map { ignore2 => () }) flatMap cpu
(p1(2) flatMap cpu) flatMap cpu
((ret flatMap cont(2)) flatMap cpu) flatMap cpu
(cont(2) flatMap cpu) flatMap cpu
((p1(3) map { ignore2 => () }) flatMap cpu) flatMap cpu
((p1(3) flatMap cpu) flatMap cpu) flatMap cpu
(((ret flatMap cont(3)) flatMap cpu) flatMap cpu) flatMap cpu

等等......你看到内存消耗线性地取决于N。我们只是将评估从堆栈移动到堆。

带走:要保持Free内存友好,请将递归保留在&#34;尾部位置&#34;,即{{1}的右侧(或flatMap)。

旁白:由于map已经贬值,因此无需翻译Trampoline。您可以直接解释为Free并使用Id进行堆栈安全解释:

foldMapRec

这会让你恢复一些记忆(但不会让问题消失)。