在Cython中定义NumPy数组而不会产生python开销

时间:2013-08-05 02:27:12

标签: python python-2.7 numpy cython

我一直在努力学习Cython来加速我的一些计算。这是我要做的事情的一个子集:这只是在使用NumPy数组的同时使用递归公式简化微分方程。我已经实现了比纯python版本提高约100倍速度的因素。但是,基于查看-a cython命令为我的代码生成的HTML文件,我似乎可以获得更快的速度。我的代码如下(在HTML文件中变成黄色的行,我想要标记为白色):

%%cython
import numpy as np
cimport numpy as np
cimport cython
from libc.math cimport exp,sqrt

@cython.boundscheck(False)
cdef double riccati_int(double j, double w, double h, double an, double d):
    cdef:
        double W
        double an1
    W = sqrt(w**2 + d**2)
    #dark_yellow
    an1 = ((d - (W + w) * an) * exp(-2 * W * h / j ) - d - (W - w) * an) / 
          ((d * an - W + w) * exp(-2 * W * h / j) - d * an - W - w) 
    return an1


def acalc(double j, double w):
    cdef:
        int xpos, i, n
        np.ndarray[np.int_t, ndim=1] xvals
        np.ndarray[np.double_t, ndim=1] h, a
    xpos = 74
    xvals = np.array([0, 8, 23, 123, 218], dtype=np.int)     #dark_yellow
    h = np.array([1, .1, .01, .1], dtype=np.double)          #dark_yellow
    a = np.empty(219, dtype=np.double)                       #dark_yellow
    a[0] = 1 / (w + sqrt(w**2 + 1))                          #light_yellow

    for i in range(h.size):                                  #dark_yellow
        for n in range(xvals[i], xvals[i + 1]):              #light_yellow
            if n < xpos:
                a[n+1] = riccati_int(j, w, h[i], a[n], 1.)   #light_yellow
            else:
                a[n+1] = riccati_int(j, w, h[i], a[n], 0.)   #light_yellow
    return a  

在我看来,我上面标记的所有9条线应该能够通过适当的调整变成白色。一个问题是能够以正确的方式定义NumPy数组。但可能更重要的是能够使第一个标记线有效工作,因为这是完成大部分计算的地方。我尝试在点击黄线后阅读HTML文件显示的生成的C代码,但老实说我不知道​​如何阅读该代码。如果有人能帮助我,我们将不胜感激。

2 个答案:

答案 0 :(得分:1)

我认为你不需要关心不在循环中的黄线。添加以下编译器指令将使循环中的三行更快:

@cython.cdivision(True)
cdef double riccati_int(double j, double w, double h, double an, double d):
    pass

@cython.boundscheck(False)
@cython.wraparound(False)
def acalc(double j, double w):
    pass

答案 1 :(得分:0)

我不确定,它是否有所作为,但你可以对数组使用内存视图,例如:克。

cdef double [:] h = np.array([1, .1, .01, .1], dtype=np.double) #dark_yellow
cdef double [:] a = np.empty(219, dtype=np.double)              #dark_yellow

同样为四个静态值创建一个numpy数组有点过头了。这可以用静态C数组替换

cdef double *h = [1, .1, .01, .1]

然而,如上所述,循环中的内容最重要。由于行分析器不适用于cython(afaik),除了使用time之外,还可以使用cProfile模块在​​函数内进行基准测试。它可能会给你一个想法,即必须在上下文中评估cython日志中线条颜色的强度。

建议使用python类型进行索引,as I learned

size_t i, n
Py_ssize_t i, n

第二个是签名版本