我正在玩背包(没有真正的理由,只是试图去除一些生锈)并想用我最喜欢的语言实现它
(请不要笑,从大学开始已经有一段时间了,而且我对Scala很新)
这是我的第一次运行(它返回正确的解决方案,但我认为它远非最佳):
import scala.collection.mutable.HashMap
object Main {
def main(args: Array[String]) {
val weights = List(23, 31, 29, 44, 53, 38, 63, 85, 89, 82)
val values = List(92, 57, 49, 68, 60, 43, 67, 84, 87, 72)
val wv = weights zip values
val solver = new KnapSackSolver()
solver.solve(wv, 165)
}
class KnapSackSolver() {
var numberOfIterations = 0
type Item = (Int, Int)
type Items = List[Item]
val cache = new HashMap[(Items, Int), Items]()
def sackValue(s: Items) = if (s.isEmpty) 0 else s.map(_._2).sum
def solve(wv: Items, capacity: Int) = {
numberOfIterations = 0
val solution = knapsack(wv, capacity)
println(s"""|Solution: $solution
|Value: ${sackValue(solution)}
|Number of iterations: $numberOfIterations
""".stripMargin)
solution
}
private[this] def knapsack(wv: Items, capacity: Int): Items = {
numberOfIterations +=1
val cacheKey = (wv, capacity)
if (cache.contains(cacheKey)) {
return cache(cacheKey) //I know, I wrote a return, just wanted an early exit
}
if (capacity <= 0 || wv.isEmpty) {
Nil
} else if (wv.head._1 > capacity) {
knapsack(wv.tail, capacity)
} else {
val sackNotTakingCurrent = knapsack(wv.tail, capacity)
val sackTakingCurrent = knapsack(wv.tail, capacity - wv.head._1) :+ wv.head
val notTakingCurrentValue = sackValue(sackNotTakingCurrent)
val takingCurrentValue = sackValue(sackTakingCurrent)
val ret =
if (notTakingCurrentValue >= takingCurrentValue) sackNotTakingCurrent
else sackTakingCurrent
cache(cacheKey) = ret
ret
}
}
}
}
问题
我天真的“缓存”似乎不够好(565 vs 534次迭代)但我不确定如何改进它,我有一种感觉,填充一个大小的项目矩阵的矩阵涉及,但不知道从哪里去这里。
换句话说 - 这是最佳解决方案吗?对我来说感觉非常指数,但如果我说我理解伪多项式的真正含义,那么我会说谎... 如果这不是最佳解决方案,我怀疑它不是,那么我错过了什么?
答案 0 :(得分:0)
我想我发现了我的问题,忘了缓存一些情况,这是我的解决方案,下至472次迭代(1650次迭代方法,即N * W)
class KnapSackSolver() {
var numberOfIterations = 0
type Item = (Int, Int)
type Items = List[Item]
val cache = new HashMap[(Items, Int), Items]()
def sackValue(s: Items) = if (s.isEmpty) 0 else s.map(_._2).sum
def solve(wv: Items, capacity: Int) = {
numberOfIterations = 0
val solution = knapsack(wv, capacity)
println(s"""|Solution: $solution
|Value: ${sackValue(solution)}
|Number of iterations: $numberOfIterations
""".stripMargin)
solution
}
private[this] def knapsack(wv: Items, capacity: Int): Items = {
numberOfIterations += 1
val cacheKey = (wv, capacity)
if (cache.contains(cacheKey)) {
cache(cacheKey)
} else {
val ret =
if (capacity <= 0 || wv.isEmpty) {
Nil
} else if (wv.head._1 > capacity) {
knapsack(wv.tail, capacity)
} else {
val sackNotTakingCurrent = knapsack(wv.tail, capacity)
val sackTakingCurrent = wv.head :: knapsack(wv.tail, capacity - wv.head._1)
val notTakingCurrentValue = sackValue(sackNotTakingCurrent)
val takingCurrentValue = sackValue(sackTakingCurrent)
if (notTakingCurrentValue >= takingCurrentValue) sackNotTakingCurrent
else sackTakingCurrent
}
cache(cacheKey) = ret
ret
}
}
}
答案 1 :(得分:0)
这是scala代码,可以解释你的概念(它没有在REPL上测试,但会为你提供如何处理背包问题的直觉)
def KnapsackSolver(Weights: List[Int],Values: List[Int],Capacity: Int): (Int,List[Int]) {
val cache = new HashMap((Int,Int),Int)()
def solve(W: List[Int],V: List[Int],C: Int) : Int = {
if(W.Length<1)
0
else {
val currV = V.head
val currW = W.head
val Key1 = (W.length-1,C)
val Key2 = (W.length-1,C-currW)
val sum1 =
if(cache.containsKey(Key1))
cache(Key1)
else solve(W.tail,V.tail,C)
val sum2 =
if(currW<=C) {
if(cache.containsKey(Key2))
cache(Key2)
else solve(W.tail,V.tail,C-currW) + currV
}
else 0
cache((W.length,C)) = math.max(sum1,sum2)
math.max(sum1,sum2)
}
}
def traceSol(C: Int,W: List[Int]): List[Int] = {
if(W.Length<1)
nil
else {
val sum1 = cache((W.Length-1,C))
val sum2 = cache((W.Length,C))
if(sum1==sum2)
traceSol(C,W.tail)
else W.Length :: traceSol(C-W.head,W.tail)
}
}
val optval = solve(Weights,Values,Capacity)
val solution = traceSol(Capacity,Weights)
(optval,solution)
}