加速功能评估,以便在scipy中集成

时间:2015-02-02 13:02:15

标签: python numpy scipy

我正在尝试将代码从Matlab移植到SciPy。这是我到目前为止编写的代码的简化版本:https://gist.github.com/atmo/01b6e007be9ef90e402c。但是,Python版本比Matlab慢得多。我在gist中包含了性能分析结果,并且它们显示了几乎90%的时间python花费在评估函数f上。有没有办法加快评估速度,除非用C或Cython重写它?

2 个答案:

答案 0 :(得分:1)

您的numpy版本可能与旧的MATLAB运行速度相当。但新的MATLAB版本可以进行各种形式的即时编译,从而大大加快了重复计算的速度。

我的猜测是,您可以抄袭lambdaf代码,并可能将评估时间缩短一半。但真正的杀手是你多次打电话f

首先,我会尝试预先计算f中的内容。例如,在计算中定义K1=K[1]并使用K1。这将减少索引调用的数量。指数重复了吗?也许将lambda定义替换为常规def,或将其与f合并。

答案 1 :(得分:1)

正如我在评论中提到的,如果考虑到矩阵是对称的,那么你可以去掉大约一半的quad(以及复杂函数f)的调用

通过重写那个复杂的功能,仍然可以在纯python中获得进一步的速度提升。我在同情中完成了大部分内容。

最后,我尝试使用quad将调用向量化为np.vectorize

from scipy.integrate import quad
from scipy.special import jn as besselj
from scipy import exp, zeros, linspace
from scipy.linalg import norm
import numpy as np

def complicated_func(lmbd, a, n, k):
    u,v,w = 5, 3, 2
    x = a*lmbd
    fac = exp(2*x)
    comm = (2*w + x)
    part1 = ((v**2 + 4*w*(w + 2*x) + 2*x*(x - 1))*fac**5
        + 2*u*fac**4
        + (-v**2 - 4*(w*(3*w + 4*x + 1) + x*(x-2)) + 1)*fac**3
        + (-8*(w + x) + 2)*fac**2
        + (2*comm*(comm + 1) - 1)*fac)
    return part1/lmbd *besselj(n+1, lmbd) * besselj(k+1, lmbd)

def perform_quad(n, k, a):
    return quad(complicated_func, 0, np.inf, args=(a,n,k))[0]

def improved_main():
    sz = 20
    amatrix = np.zeros((sz,sz))
    ls = -np.linspace(1, 10, 20)/2
    inds = np.tril_indices(sz)
    myv3 = np.vectorize(perform_quad)
    res = myv3(inds[0], inds[1], ls.reshape(-1,1))
    results = np.empty(res.shape[0])
    for rowind, row in enumerate(res):
        amatrix[inds] = row
        symm_matrix = amatrix + amatrix.T - np.diag(amatrix.diagonal())
        results[rowind] = norm(symm_matrix)
    return results

计时结果显示我的速度增加了5倍(如果我只运行一次,你会原谅我,它需要足够长的时间):

In [11]: %timeit -n1 -r1 improved_main()
1 loops, best of 1: 6.92 s per loop

In [12]: %timeit -n1 -r1 main()
1 loops, best of 1: 35.9 s per loop

如果你用它的方块立即替换v,那么还有一个microgain,因为这是唯一一次用于那个复杂的函数:就像它的正方形一样。

besselj的调用也有极大的重复次数,但我看不出如何避免这种情况,因为quad会确定lmbd,所以你不能轻松预先计算这些值,然后执行查找。

如果您对improved_main进行了分析,您会发现complicated_func的呼叫数量几乎减少了2倍(对角线仍然需要计算)。所有其他速度增益可归因于np.vectorize以及complicated_func的改进。

我的系统上没有Matlab,所以如果你改进那里复杂的功能,我就不能说它的速度增益。