添加memoization到递归背包解决方案

时间:2013-11-13 18:19:59

标签: algorithm scala knapsack-problem

我正在玩背包(没有真正的理由,只是试图去除一些生锈)并想用我最喜欢的语言实现它

(请不要笑,从大学开始已经有一段时间了,而且我对Scala很新)

这是我的第一次运行(它返回正确的解决方案,但我认为它远非最佳):

import scala.collection.mutable.HashMap

object Main {
  def main(args: Array[String]) {
    val weights = List(23, 31, 29, 44, 53, 38, 63, 85, 89, 82)
    val values = List(92, 57, 49, 68, 60, 43, 67, 84, 87, 72)
    val wv = weights zip values
    val solver = new KnapSackSolver()
    solver.solve(wv, 165) 

  }

  class KnapSackSolver() {

    var numberOfIterations = 0
    type Item = (Int, Int)
    type Items = List[Item]

    val cache = new HashMap[(Items, Int), Items]()

    def sackValue(s: Items) = if (s.isEmpty) 0 else s.map(_._2).sum

    def solve(wv: Items, capacity: Int) = {
      numberOfIterations = 0
      val solution = knapsack(wv, capacity)

      println(s"""|Solution: $solution
                  |Value: ${sackValue(solution)}
                  |Number of iterations: $numberOfIterations
      """.stripMargin)

      solution 
    }

    private[this] def knapsack(wv: Items, capacity: Int): Items = {
      numberOfIterations +=1
      val cacheKey = (wv, capacity)
      if (cache.contains(cacheKey)) {
        return cache(cacheKey) //I know, I wrote a return, just wanted an early exit
      }

      if (capacity <= 0 || wv.isEmpty) {
        Nil
      } else if (wv.head._1 > capacity) {
        knapsack(wv.tail, capacity)
      } else {
        val sackNotTakingCurrent = knapsack(wv.tail, capacity)
        val sackTakingCurrent = knapsack(wv.tail, capacity - wv.head._1) :+ wv.head

        val notTakingCurrentValue = sackValue(sackNotTakingCurrent)
        val takingCurrentValue = sackValue(sackTakingCurrent)
        val ret =
          if (notTakingCurrentValue >= takingCurrentValue) sackNotTakingCurrent
          else sackTakingCurrent

        cache(cacheKey) = ret
        ret
      }
    }
  }
}

问题

我天真的“缓存”似乎不够好(565 vs 534次迭代)但我不确定如何改进它,我有一种感觉,填充一个大小的项目矩阵的矩阵涉及,但不知道从哪里去这里。

换句话说 - 这是最佳解决方案吗?对我来说感觉非常指数,但如果我说我理解伪多项式的真正含义,那么我会说谎... 如果这不是最佳解决方案,我怀疑它不是,那么我错过了什么?

2 个答案:

答案 0 :(得分:0)

我想我发现了我的问题,忘了缓存一些情况,这是我的解决方案,下至472次迭代(1650次迭代方法,即N * W)

class KnapSackSolver() {

    var numberOfIterations = 0
    type Item = (Int, Int)
    type Items = List[Item]

    val cache = new HashMap[(Items, Int), Items]()

    def sackValue(s: Items) = if (s.isEmpty) 0 else s.map(_._2).sum

    def solve(wv: Items, capacity: Int) = {
      numberOfIterations = 0
      val solution = knapsack(wv, capacity)

      println(s"""|Solution: $solution
                  |Value: ${sackValue(solution)}
                  |Number of iterations: $numberOfIterations
      """.stripMargin)

      solution
    }

    private[this] def knapsack(wv: Items, capacity: Int): Items = {
      numberOfIterations += 1
      val cacheKey = (wv, capacity)
      if (cache.contains(cacheKey)) {
        cache(cacheKey)
      } else {
        val ret =
          if (capacity <= 0 || wv.isEmpty) {
            Nil
          } else if (wv.head._1 > capacity) {
            knapsack(wv.tail, capacity)
          } else {
            val sackNotTakingCurrent = knapsack(wv.tail, capacity)
            val sackTakingCurrent = wv.head :: knapsack(wv.tail, capacity - wv.head._1)

            val notTakingCurrentValue = sackValue(sackNotTakingCurrent)
            val takingCurrentValue = sackValue(sackTakingCurrent)
            if (notTakingCurrentValue >= takingCurrentValue) sackNotTakingCurrent
            else sackTakingCurrent

          }
        cache(cacheKey) = ret
        ret
      }

    }
  }

答案 1 :(得分:0)

这是scala代码,可以解释你的概念(它没有在REPL上测试,但会为你提供如何处理背包问题的直觉)

def KnapsackSolver(Weights: List[Int],Values: List[Int],Capacity: Int): (Int,List[Int]) {

    val cache = new HashMap((Int,Int),Int)()

    def solve(W: List[Int],V: List[Int],C: Int) : Int  = {

       if(W.Length<1)
          0
       else {

         val currV = V.head
         val currW = W.head
         val Key1 = (W.length-1,C)
         val Key2 = (W.length-1,C-currW)
         val sum1 = 
            if(cache.containsKey(Key1)) 
                 cache(Key1)
             else solve(W.tail,V.tail,C)
         val sum2 = 
           if(currW<=C) {
              if(cache.containsKey(Key2))
                 cache(Key2)
              else solve(W.tail,V.tail,C-currW) + currV
            }
            else 0 
         cache((W.length,C)) = math.max(sum1,sum2)
         math.max(sum1,sum2)



       }
    }

    def traceSol(C: Int,W: List[Int]): List[Int] = {

         if(W.Length<1)
           nil
         else {
             val sum1 = cache((W.Length-1,C))
             val sum2 = cache((W.Length,C)) 
             if(sum1==sum2)
                traceSol(C,W.tail)
             else W.Length :: traceSol(C-W.head,W.tail)
         } 

    }

   val optval = solve(Weights,Values,Capacity)
   val solution = traceSol(Capacity,Weights)
   (optval,solution)

}