提高cython数组索引速度

时间:2014-05-08 18:58:58

标签: python arrays numpy cython

我有一个非常简单的功能,我需要加快速度。基本上我有一个16位数字的大数组,其中有一些洞。 (大约10%)我需要遍历数组,找到连续有2 0的区域,然后用前一个和下一个元素的平均值填充它们。这在C中只需要几毫秒,但Python正在变得更糟。

我已经从常规python数组转换为numpy数组,然后使用cython编译我的代码,但我仍然离我的目标很远。我希望有更多经验的人可以看看我在做什么并给我一些反馈。

我的常规python代码如下所示:

self.rawData = numpy.fromfile(ql, numpy.uint16, 50000)
[snip]
def fixZeroes(self):
    for x in range(2,len(self.rawData)):
        if self.rawData[x] == 0 and self.rawData[x-1] == 0:
            self.rawData[x] = (self.rawData[x-2] + self.rawData[x+2]) / 2
            self.rawData[x-1] = (self.rawData[x-3] + self.rawData[x+1]) /2

我的Cython代码看起来非常相似:

import numpy as np
cimport numpy as np
DTYPE = np.uint16
ctypedef np.uint16_t DTYPE_t

@cython.boundscheck(False)
def fix_zeroes(np.ndarray[DTYPE_t, ndim=1] raw):
    assert raw.dtype == DTYPE
    cdef int len = 50000

    for x in range(2,len):
        if raw[x] == 0 and raw[x-1] == 0:
            raw[x] = (raw[x-2] + raw[x+2]) / 2
        raw[x-1] = (raw[x-3] + raw[x+1]) /2
    return raw

当我运行此代码时,性能仍然比我喜欢的速度慢:

  

启动cython零修复

     

完成:0:00:36.983681

     

启动python零修复

     

完成:0:00:41.434476

我真的认为我一定做错了。我所看到的大多数文章都谈到了numpy和cython增加的巨大性能提升,但我几乎没有突破10%。

1 个答案:

答案 0 :(得分:5)

您应该声明用于索引x数组的raw变量:

cdef int x

您还可以使用通常可以提升性能的其他指令:

@cython.wraparound(False)
@cython.cdivision(True)
@cython.nonecheck(False)