再没有Cython的速度提升?

时间:2014-02-17 18:35:12

标签: python performance cython

以下是我的cython代码,目的是做一个bootstrap。

def boots(int trial, np.ndarray[double, ndim=2] empirical, np.ndarray[double, ndim=2] expected):
    cdef int length = len(empirical)
    cdef np.ndarray[double, ndim=2] ret = np.empty((trial, 100))
    cdef np.ndarray[long] choices
    cdef np.ndarray[double] m
    cdef np.ndarray[double] n
    cdef long o
    cdef int i
    cdef int j

    for i in range(trial):
        choices = np.random.randint(0, length, length)

        m = np.zeros(100)
        n = np.zeros(100)
        for j in range(length):
            o = choices[j]
            m.__iadd__(empirical[o])
            n.__iadd__(expected[o])
        empirical_boot = m / length
        expected_boot = n / length

        ret[i] = empirical_boot / expected_boot - 1
    ret.sort(axis=0)
    return ret[int(trial * 0.025)].reshape((10,10)), ret[int(trial * 0.975)].reshape((10,10))


# test code
empirical = np.ones((40000, 100))
expected = np.ones((40000, 100))
%prun -l 10 boots(100, empirical,expected)

纯粹的python需要花费11秒才能获得精美的索引,无论我在cython中如何努力,它都保持不变。

np.random.randint(0, 40000, 40000)需要1 ms,因此100x需要0.1秒。

np.sort(np.ones((40000, 100))需要0.2秒。

因此,我觉得必须有办法改进boots

1 个答案:

答案 0 :(得分:3)

您看到的主要问题是Cython仅针对类型化数组优化单项访问。这意味着您在NumPy中使用矢量化的代码中的每一行仍然涉及创建Python对象并与之交互。 你在那里的代码并不比纯Python版本快,因为它并没有真正以不同的方式进行任何计算。 您必须通过明确写出循环操作来避免这种情况。 以下是代码的修改版本,运行速度要快得多。

from numpy cimport ndarray as ar
from numpy cimport int32_t as int32
from numpy import empty
from numpy.random import randint
cimport cython
ctypedef int

# Notice the use of these decorators to tell Cython to turn off
# some of the checking it does when accessing arrays.
@cython.boundscheck(False)
@cython.wraparound(False)
def boots(int32 trial, ar[double, ndim=2] empirical, ar[double, ndim=2] expected):
    cdef:
        int32 length = empirical.shape[0], i, j, k
        int32 o
        ar[double, ndim=2] ret = empty((trial, 100))
        ar[int32] choices
        ar[double] m = empty(100), n = empty(100)
    for i in range(trial):
        # Still calling Python on this line
        choices = randint(0, length, length)
        # It was faster to compute m and n separately.
        # I suspect that has to do with cache management.
        # Instead of allocating new arrays, I just filled the old ones with the new values.
        o = choices[0]
        for k in range(100):
            m[k] = empirical[o,k]
        for j in range(1, length):
            o = choices[j]
            for k in range(100):
                m[k] += empirical[o,k]
        o = choices[0]
        for k in range(100):
            n[k] = expected[o,k]
        for j in range(1, length):
            o = choices[j]
            for k in range(100):
                n[k] += expected[o,k]
        # Here I simplified some of the math and got rid of temporary arrays
        for k in range(100):
            ret[i,k] = m[k] / n[k] - 1.
    ret.sort(axis=0)
    return ret[int(trial * 0.025)].reshape((10,10)), ret[int(trial * 0.975)].reshape((10,10))

如果您想查看代码的哪些行涉及Python调用,Cython编译器可以生成一个html文件,显示哪些行调用Python。 此选项称为注释。 你使用它的方式取决于你如何编译你的cython代码。 如果您使用的是IPython笔记本,只需将--annotate标志添加到Cython单元格魔术中。

您也可以从打开C编译器优化标志中受益。