我正在使用Python 2.7,并且有一个程序可以解决递归优化问题,即动态编程问题。该代码的简化版本是:
from math import log
from scipy.optimize import minimize_scalar
class vT(object):
def __init__(self,c):
self.c = c
def x(self,w):
return w
def __call__(self,w):
return self.c*log(self.x(w))
class vt(object):
def __init__(self,c,vN):
self.c = c
self.vN = vN
def objFunc(self,x,w):
return -self.c*log(x) - self.vN(w - x)
def x(self,w):
x_star = minimize_scalar(self.objFunc,args=(w,),method='bounded',
bounds=(1e-10,w-1e-10)).x
return x_star
def __call__(self,w):
return self.c*log(self.x(w)) + self.vN(w - self.x(w))
p3 = vT(2.0)
p2 = vt(2.0,p3)
p1 = vt(2.0,p2)
w1 = 3.0
x1 = p1.x(w1)
w2 = w1 - x1
x2 = p2.x(w2)
w3 = w2 - x2
x3 = w3
x = [x1,x2,x3]
print('Optimal x when w1 = 3 is ' + str(x))
如果添加了足够的时间,则该程序可能开始花费很长时间才能运行。运行x1 = p1.x(w1)
时,p2
对p3
和minimize_scalar
进行了多次评估。另外,运行x2 = p2(w2)
时,我们知道最终的解决方案将涉及以第一步中已经完成的方式评估p2
和p3
。
我有两个问题:
vT
和vt
类上使用备忘录包装器来加速该程序的最佳方法是什么?minimize_scalar
时,它会从此备注中受益吗?在我的实际应用程序中,当前解决方案可能要花几个小时才能解决。因此,加快速度将非常有价值。
更新:以下答复指出,上面的示例可以在不使用类的情况下编写,并且正常修饰可以用于函数。在我的实际应用程序中,我必须使用类而不是函数。此外,我的第一个问题是minimize_scalar
内部的函数或方法(当它是一个类时)的调用是否将从备忘录中受益。
答案 0 :(得分:0)
我找到了答案。下面是如何记忆程序的示例。可能有一种更有效的方法,但是这种方法可以记住该类的方法。此外,运行minimize_scalar
时,备忘录包装器每次评估函数时都会记录结果:
from math import log
from scipy.optimize import minimize_scalar
from functools import wraps
def memoize(obj):
cache = obj.cache = {}
@wraps(obj)
def memoizer(*args, **kwargs):
key = str(args) + str(kwargs)
if key not in cache:
cache[key] = obj(*args, **kwargs)
return cache[key]
return memoizer
class vT(object):
def __init__(self,c):
self.c = c
@memoize
def x(self,w):
return w
@memoize
def __call__(self,w):
return self.c*log(self.x(w))
class vt(object):
def __init__(self,c,vN):
self.c = c
self.vN = vN
@memoize
def objFunc(self,x,w):
return -self.c*log(x) - self.vN(w - x)
@memoize
def x(self,w):
x_star = minimize_scalar(self.objFunc,args=(w,),method='bounded',
bounds=(1e-10,w-1e-10)).x
return x_star
@memoize
def __call__(self,w):
return self.c*log(self.x(w)) + self.vN(w - self.x(w))
p3 = vT(2.0)
p2 = vt(2.0,p3)
p1 = vt(2.0,p2)
x1 = p1.x(3.0)
len(p3.x.cache) # how many times was p3.x evaluated?
出[3]:60
x2 = p2.x(3.0 - x1)
len(p3.x.cache) # how many additional times was p3.x evaluated?
出[5]:60