由于我将此代码调用超过1000次,有没有办法优化此代码需要1.73秒?
def generate():
S0 = 0
T = 1.
nt = 100000
lbd = 500.
mu = 0
sigma = 1.
dt = T/nt
St = [S0] * nt
sqrtdt = np.sqrt(dt)
dBt = np.random.normal(0, sqrtdt, nt)
for k in xrange(1, nt):
dSt = lbd * (mu - St[k-1]) * dt + sigma * dBt[k]
St[k] = St[k-1] + dSt
return St
答案 0 :(得分:5)
你可以从for-loop
中挤出更多的工作,但同时生成所有路径(假设你有足够的内存):
import numpy as np
def generate_orig(T=1., nt=100000, lbd=500., mu=0, sigma=1., S0=0):
dt = T/nt
St = [S0] * nt
sqrtdt = np.sqrt(dt)
dBt = np.random.normal(0, sqrtdt, nt)
for k in xrange(1, nt):
dSt = lbd * (mu - St[k-1]) * dt + sigma * dBt[k]
St[k] = St[k-1] + dSt
return St
def generate(T=1., nt=100000, lbd=500., mu=0, sigma=1., S0=0, npaths=1):
dt = T/nt
St = np.full((nt, npaths), S0)
sqrtdt = np.sqrt(dt)
dBt = np.random.normal(0, sqrtdt, size=(nt, npaths))
for k in xrange(1, nt):
dSt = lbd * (mu - St[k-1]) * dt + sigma * dBt[k]
St[k] = St[k-1] + dSt
return St
以下是100条路径的时间基准。
In [55]: %timeit [generate_orig() for i in xrange(100)]
1 loops, best of 3: 23.6 s per loop
In [56]: %timeit generate(npaths=100)
1 loops, best of 3: 1.97 s per loop
您也可以通过使用Cython来提高for-loop
的性能。
答案 1 :(得分:1)
我想为unutbu's answer提供两种替代解决方案。他写的是 The Right Thing™,如果您不想依赖Cython或JIT编译器和,那么生成输出{ {1}}批量生产。
我从他的回答中抓取St
并将Python列表generate_orig()
变成了一个numpy数组:
St
时序:
import numpy as np
def generate_orig(T=1., nt=100000, lbd=500., mu=0, sigma=1., S0=0):
dt = T/nt
St = np.full(nt, fill_value=S0, dtype=np.float64)
sqrtdt = np.sqrt(dt)
dBt = np.random.normal(0, sqrtdt, nt)
for k in xrange(1, nt):
dSt = lbd * (mu - St[k-1]) * dt + sigma * dBt[k]
St[k] = St[k-1] + dSt
return St
到目前为止没有任何改进,与以前一样。但是,使用Numba,只需添加%timeit [generate_orig() for i in xrange(100)]
1 loops, best of 3: 25.4 s per loop
:
@autojit
时间下降:
import numpy as np
from numba import autojit
@autojit
def generate_orig(T=1., nt=100000, lbd=500., mu=0, sigma=1., S0=0):
# The rest is exactly the same as before
我认为真棒! 40倍加速只是为了添加%timeit [generate_orig(1., 100000, 500., 0, 1., 0) for i in xrange(100)]
1 loops, best of 3: 642 ms per loop
!
这是一个带typed memoryviews的Cython版本:
@autojit
定时:
%%cython
# cython: infer_types=True
# cython: boundscheck=False
# cython: wraparound=False
import numpy as np
cimport numpy as np
def generate_cython(double T=1., int nt=100000, double lbd=500., double mu=0, double sigma=1., double S0=0):
cdef int k
cdef double dt, dSt
cdef double[:] vSt, vdBt
dt = T/nt
St = np.full(nt, fill_value=S0, dtype=np.float64)
vSt = St
vdBt = np.random.normal(0.0, np.sqrt(dt), nt)
for k in xrange(1, nt):
dSt = lbd * (mu - vSt[k-1]) * dt + sigma * vdBt[k]
vSt[k] = vSt[k-1] + dSt
return St
代码与Numba版本完全一样快(微小差别只是噪音)。但是代码变得丑陋,所有那些类型的声明都让它变得笨拙。 :(好吧,不是灾难,但仍然。
与unutbu的答案相比,这两种解决方案的速度都提高了3倍,在我的机器上运行时间为1.97秒。然而,正如我在开始时所说的那样,如果你不想依赖Cython或Numba,他的解决方案就是你的选择。 (它们都有缺点;如果有人想要避免这种依赖,这是可以理解的。)
如果我们将Numba或Cython应用于unutbu的解决方案会怎样?这会导致代码更快吗?不,与Numba一样,没有区别。 Cython让事情变得更糟。好吧,也许Cython大师可以提出更好的解决方案......