Scala Memoization: How does this Scala memo work?

Time: 2014-08-05 00:52:11

Tags: scala dynamic-programming memoization

The following code is from Pathikrit's Dynamic Programming repository. I'm mystified by both its beauty and its peculiarity.

def subsetSum(s: List[Int], t: Int) = {
  type DP = Memo[(List[Int], Int), (Int, Int), Seq[Seq[Int]]]
  implicit def encode(key: (List[Int], Int)) = (key._1.length, key._2)

  lazy val f: DP = Memo {
    case (Nil, 0) => Seq(Nil)
    case (Nil, _) => Nil
    case (a :: as, x) => (f(as, x - a) map {_ :+ a}) ++ f(as, x)
  }

  f(s, t)
}

The type Memo is implemented in another file:

case class Memo[I <% K, K, O](f: I => O) extends (I => O) {
  import collection.mutable.{Map => Dict}
  val cache = Dict.empty[K, O]
  override def apply(x: I) = cache getOrElseUpdate (x, f(x))
}

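My questions are:

  1. Why is the type K declared as (Int, Int) in subsetSum?

  2. What do the two Ints in (Int, Int) stand for, respectively?

  3. How does (List[Int], Int) get implicitly converted to (Int, Int)?
    I see no implicit def foo(x:(List[Int],Int)) = (x._1.toInt,x._2) (not even in the Implicits.scala file it imports).

    *Edit: Well, I missed this:

    implicit def encode(key: (List[Int], Int)) = (key._1.length, key._2)

I enjoy Pathikrit's library scalgos very much. There are a lot of Scala pearls in it. Please help me with this so I can appreciate Pathikrit's wit. Thank you. (: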

1 Answer:

Answer 0: (score: 53)

I am the author of the above code.

/**
 * Generic way to create memoized functions (even recursive and multiple-arg ones)
 *
 * @param f the function to memoize
 * @tparam I input to f
 * @tparam K the keys we should use in cache instead of I
 * @tparam O output of f
 */
case class Memo[I <% K, K, O](f: I => O) extends (I => O) {
  import collection.mutable.{Map => Dict}
  type Input = I
  type Key = K
  type Output = O
  val cache = Dict.empty[K, O]
  override def apply(x: I) = cache getOrElseUpdate (x, f(x))
}

object Memo {
  /**
   * Type of a simple memoized function e.g. when I = K
   */
  type ==>[I, O] = Memo[I, I, O]
}

Memo[I <% K, K, O]

I: input
K: key to lookup in cache
O: output

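I <% K means that K is viewable from I, i.e. an I can be implicitly converted to a K.

In most cases I should be the same as K, e.g. if you are writing fibonacci, which is a function of type Int => Int, it is fine to cache by the Int itself.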
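As a sketch of what the view bound amounts to: `I <% K` is shorthand for an extra implicit parameter of type `I => K`. The class below writes that parameter out by hand, which also avoids the `<%` syntax (it is deprecated in later Scala 2 releases and not accepted by Scala 3). The names `ViewMemo`, `toKey` and `fibSketch` are illustrative assumptions and are not part of scalgos.

```scala
import scala.collection.mutable

// Hypothetical stand-in for Memo: the view bound I <% K is spelled out
// as an implicit parameter of type I => K.
class ViewMemo[I, K, O](f: I => O)(implicit toKey: I => K) extends (I => O) {
  val cache = mutable.Map.empty[K, O]

  override def apply(x: I): O = {
    val key = toKey(x)
    cache.get(key) match {
      case Some(hit) => hit
      case None =>
        val result = f(x) // may recurse into apply and fill other entries
        cache.update(key, result)
        result
    }
  }
}

// I = K: for Fibonacci the Int input is already a good cache key,
// so the key function is just the identity (passed explicitly here).
lazy val fibSketch: ViewMemo[Int, Int, BigInt] =
  new ViewMemo[Int, Int, BigInt]({
    case 0 => BigInt(0)
    case 1 => BigInt(1)
    case n if n > 1 => fibSketch(n - 1) + fibSketch(n - 2)
  })(n => n)
```

When an implicit `I => K` is in scope (such as the `encode` function in the question), the second argument list can be left out and the compiler supplies it.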

But sometimes, when you are writing memoization, you do not want to memoize or cache by the input itself (I) but rather by a function of the input (K). For example, when you are writing the subsetSum algorithm, which has input of type (List[Int], Int), you do not want to use the List[Int] as the key in your cache; you want to use List[Int].size as part of the key instead.

So, here's a concrete case:

/**
 * Subset sum algorithm - can we achieve sum t using elements from s?
 * O(s.map(abs).sum * s.length)
 *
 * @param s set of integers
 * @param t target
 * @return true iff there exists a subset of s that sums to t
 */
def isSubsetSumAchievable(s: List[Int], t: Int): Boolean = {
  type I = (List[Int], Int)     // input type
  type K = (Int, Int)           // cache key i.e. (list.size, int)
  type O = Boolean              // output type

  type DP = Memo[I, K, O]

  // encode the input as a key in the cache i.e. make K implicitly convertible from I
  implicit def encode(input: DP#Input): DP#Key = (input._1.length, input._2)

  lazy val f: DP = Memo {
    case (Nil, x) => x == 0                         // an empty sequence can only achieve a sum of zero
    case (a :: as, x) => f(as, x - a) || f(as, x)   // try with/without a.head
  }

  f(s, t)
}

You can of course shorten all of these into a single line:

type DP = Memo[(List[Int], Int), (Int, Int), Boolean]
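One way to see why (list.size, t) is a sufficient key: every list the recursion receives is a suffix of the original s, and a suffix of a fixed list is determined by its length. The sketch below does the same memoization by hand with an explicit map, so the role of each Int in the key is visible. The names `subsetSumAchievableSketch` and `go` are illustrative and do not come from scalgos.

```scala
import scala.collection.mutable

def subsetSumAchievableSketch(s: List[Int], t: Int): Boolean = {
  // key = (length of the remaining suffix of s, remaining target)
  val cache = mutable.Map.empty[(Int, Int), Boolean]

  def go(rest: List[Int], target: Int): Boolean = {
    val key = (rest.length, target)
    cache.get(key) match {
      case Some(known) => known
      case None =>
        val answer = rest match {
          case Nil     => target == 0                          // nothing left: only 0 is reachable
          case a :: as => go(as, target - a) || go(as, target) // take a, or skip a
        }
        cache.update(key, answer)
        answer
    }
  }

  go(s, t)
}
```

Keying on the list itself would also be correct, but every lookup would then hash and compare whole lists; the pair of Ints carries the same information at a fraction of the cost.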

For the common case (when I = K), you can simply do this:

type ==>[I, O] = Memo[I, I, O]

and use it like this to calculate the binomial coefficient with recursive memoization:

/**
 * http://mathworld.wolfram.com/Combination.html
 * @return memoized function to calculate C(n,r)
 */
val c: (Int, Int) ==> BigInt = Memo {
  case (_, 0) => 1
  case (n, r) if r > n/2 => c(n, n - r)
  case (n, r) => c(n - 1, r - 1) + c(n - 1, r)
}

To see in detail how the above syntax works, please refer to this question.
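In short, the (Int, Int) ==> BigInt notation relies on a general Scala rule: a type constructor with two type parameters can be written infix, so A ==> B is the same type as ==>[A, B]. Here is a minimal sketch of that rule, using a hypothetical `SimpleMemo` class in place of the scalgos `Memo`. The names `SimpleMemo` and `choose` are assumptions made for illustration.

```scala
import scala.collection.mutable

// Hypothetical single-key memoizer (the I = K case); not the scalgos class.
class SimpleMemo[I, O](f: I => O) extends (I => O) {
  private val cache = mutable.Map.empty[I, O]

  override def apply(x: I): O = cache.get(x) match {
    case Some(hit) => hit
    case None =>
      val result = f(x)
      cache.update(x, result)
      result
  }
}

// A symbolic alias with two type parameters can be used infix:
// (Int, Int) ==> BigInt is read as ==>[(Int, Int), BigInt].
type ==>[I, O] = SimpleMemo[I, O]

lazy val choose: (Int, Int) ==> BigInt =
  new SimpleMemo[(Int, Int), BigInt]({
    case (_, 0)              => BigInt(1)
    case (n, r) if r > n / 2 => choose((n, n - r))
    case (n, r)              => choose((n - 1, r - 1)) + choose((n - 1, r))
  })
```

As in the answer's version, the function is only meant for 0 <= r <= n; other inputs are not guarded against.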

Here is a full example that calculates editDistance by encoding both parameters of the input (Seq, Seq) as (Seq.length, Seq.length):

/**
 * Calculate edit distance between 2 sequences
 * O(s1.length * s2.length)
 *
 * @return Minimum cost to convert s1 into s2 using delete, insert and replace operations
 */
def editDistance[A](s1: Seq[A], s2: Seq[A]) = {

  type DP = Memo[(Seq[A], Seq[A]), (Int, Int), Int]
  implicit def encode(key: DP#Input): DP#Key = (key._1.length, key._2.length)

  lazy val f: DP = Memo {
    case (a, Nil) => a.length
    case (Nil, b) => b.length
    case (a :: as, b :: bs) if a == b => f(as, bs)
    case (a, b) => 1 + (f(a, b.tail) min f(a.tail, b) min f(a.tail, b.tail))
  }

  f(s1, s2)
}
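One caveat is worth checking before reusing this code: the a :: as pattern only matches List values, so a call with another Seq (a Vector, say) never reaches the "heads are equal" case. The following is a sketch of a variant, not the scalgos implementation, that uses explicit head/tail access so it treats every Seq the same way. The name `editDistanceSketch` and the hand-rolled cache are assumptions made to keep the block self-contained.

```scala
import scala.collection.mutable

def editDistanceSketch[A](s1: Seq[A], s2: Seq[A]): Int = {
  // key = (remaining length of s1, remaining length of s2)
  val cache = mutable.Map.empty[(Int, Int), Int]

  def go(a: Seq[A], b: Seq[A]): Int = {
    val key = (a.length, b.length)
    cache.get(key) match {
      case Some(known) => known
      case None =>
        val cost =
          if (a.isEmpty) b.length                       // insert everything left in b
          else if (b.isEmpty) a.length                  // delete everything left in a
          else if (a.head == b.head) go(a.tail, b.tail) // heads match: no cost
          else 1 + math.min(
            math.min(go(a, b.tail), go(a.tail, b)),     // insert or delete
            go(a.tail, b.tail)                          // replace
          )
        cache.update(key, cost)
        cost
    }
  }

  go(s1, s2)
}
```

The key is again a pair of suffix lengths, which works for the same reason as in subsetSum: both arguments are always suffixes of the original inputs.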

And lastly, the canonical Fibonacci example:

lazy val fib: Int ==> BigInt = Memo {
  case 0 => 0
  case 1 => 1
  case n if n > 1 => fib(n-1) + fib(n-2)
}

println(fib(100))