在numpy的示例代码中正确使用Cython

时间:2015-01-18 02:27:55

标签: python numpy cython

我想知道在使用Cyump和Numpy时我是否遗漏了一些东西,因为我没有看到太多改进。我把这段代码写成了一个例子。

朴素版:

import numpy as np
from skimage.util import view_as_windows

it = 16
arr = np.arange(1000*1000, dtype=np.float64).reshape(1000,1000)
windows = view_as_windows(arr, (it, it), it)
container = np.zeros((windows.shape[0], windows.shape[1]))
def test(windows):
    for i in range(windows.shape[0]):
        for j in range(windows.shape[1]):
            container[i,j] = np.mean(windows[i,j])
    return container

%%timeit 

test(windows)
1 loops, best of 3: 131 ms per loop

Cythonized版本:

%%cython --annotate

import numpy as np
cimport numpy as np
from skimage.util import view_as_windows
import cython
cdef int step = 16

arr = np.arange(1000*1000, dtype=np.float64).reshape(1000,1000)
windows = view_as_windows(arr, (step, step), step)

@cython.boundscheck(False)
def cython_test(np.ndarray[np.float64_t, ndim=4]  windows):
    cdef np.ndarray[np.float64_t, ndim=2] container = np.zeros((windows.shape[0], windows.shape[1]),dtype=np.float64)
    cdef int i, j
    I = windows.shape[0]
    J = windows.shape[1]
    for i in range(I):
        for j in range(J):
            container[i,j] = np.mean(windows[i,j])
    return container


%timeit cython_test(windows)
10 loops, best of 3: 126 ms per loop

正如你所看到的,有一个非常适度的改进,所以也许我做错了什么。顺便说一下,Cython产生的注释如下:

enter image description here

正如您所看到的,即使包含有效的索引语法np.ndarray[DTYPE_t, ndim=2],numpy行仍具有黄色背景。为什么呢?

顺便说一句,在我看来,理想的结果是能够使用大多数numpy函数,但在利用高效的索引语法或者像HYRY的答案中的内存视图后仍然得到一些合理的改进。

更新

似乎我在上面发布的代码中没有做任何错误,并且某些行中的黄色背景是正常的,所以我想知道以下内容:在哪些情况下我可以从键入{{1 numpy数组前面?我想有一些具体的例子,这有用,否则就没有太多的目的。

1 个答案:

答案 0 :(得分:3)

你需要自己实现mean()函数来加速代码,这是因为调用numpy函数的开销非常高。

@cython.boundscheck(False)
@cython.wraparound(False)
def cython_test(double[:, :, :, :]  windows):
    cdef double[:, ::1] container
    cdef int i, j, k, l
    cdef int n0, n1, n2, n3
    cdef double inv_n
    cdef double s
    n0, n1, n2, n3 = windows.base.shape
    container = np.zeros((n0, n1))
    inv_n = 1.0 / (n2 * n3)
    for i in range(n0):
        for j in range(n1):
            s = 0
            for k in range(n2):
                for l in range(n3):
                    s += windows[i, j, k, l]
            container[i,j] = s * inv_n
    return container.base

以下是%timeit结果:

  • python_test(windows):63.7 ms
  • cython_test(windows):1.24 ms
  • np.mean(windows, axis=(2, 3)):2.66 ms