Cython,numpy加速

时间:2015-10-10 22:27:57

标签: python performance numpy cython

我正在尝试编写一种算法来计算2D数组的某些相邻元素的平均值。

我想看看是否可以使用Cython加速它,但这是我第一次使用它。

Python版

import numpy as np

def clamp(val, minval, maxval):
    return max(minval, min(val, maxval))


def filter(arr, r):
    M = arr.shape[0]
    N = arr.shape[1]

    new_arr = np.zeros([M, N], dtype=np.int)

    for x in range(M):
        for y in range(N):
            # Corner elements
            p1 = clamp(x-r, 0, M)
            p2 = clamp(y-r, 0, N)
            p3 = clamp(y+r, 0, N-1)
            p4 = clamp(x+r, 0, M-1)

            nbr_elements = (p3-p2-1)*2+(p4-p1-1)*2+4

            tmp = 0

            # End points
            tmp += arr[p1, p2]
            tmp += arr[p1, p3]
            tmp += arr[p4, p2]
            tmp += arr[p4, p3]

            # The rest
            tmp += sum(arr[p1+1:p4, p2])
            tmp += sum(arr[p1+1:p4, p3])
            tmp += sum(arr[p1, p2+1:p3])
            tmp += sum(arr[p4, p2+1:p3])

            new_arr[x, y] = tmp/nbr_elements

    return new_arr

和我尝试Cython实现。我发现如果你重新实现它们,max / min / sum会更快,而不是使用python版本

Cython版

from __future__ import division
import numpy as np
cimport numpy as np

DTYPE = np.int
ctypedef np.int_t DTYPE_t

cdef inline int int_max(int a, int b): return a if a >= b else b
cdef inline int int_min(int a, int b): return a if a <= b else b

def clamp(int val, int minval, int maxval):
    return int_max(minval, int_min(val, maxval))

def cython_sum(np.ndarray[DTYPE_t, ndim=1] y):
    cdef int N = y.shape[0]
    cdef int x = y[0]
    cdef int i
    for i in xrange(1, N):
        x += y[i]
    return x


def filter(np.ndarray[DTYPE_t, ndim=2] arr, int r):
    cdef M = im.shape[0]
    cdef N = im.shape[1]

    cdef np.ndarray[DTYPE_t, ndim=2] new_arr = np.zeros([M, N], dtype=DTYPE)
    cdef int p1, p2, p3, p4, nbr_elements, tmp

    for x in range(M):
        for y in range(N):
            # Corner elements
            p1 = clamp(x-r, 0, M)
            p2 = clamp(y-r, 0, N)
            p3 = clamp(y+r, 0, N-1)
            p4 = clamp(x+r, 0, M-1)

            nbr_elements = (p3-p2-1)*2+(p4-p1-1)*2+4

            tmp = 0

            # End points
            tmp += arr[p1, p2]
            tmp += arr[p1, p3]
            tmp += arr[p4, p2]
            tmp += arr[p4, p3]

            # The rest
            tmp += cython_sum(arr[p1+1:p4, p2])
            tmp += cython_sum(arr[p1+1:p4, p3])
            tmp += cython_sum(arr[p1, p2+1:p3])
            tmp += cython_sum(arr[p4, p2+1:p3])

            new_arr[x, y] = tmp/nbr_elements

    return new_arr

我制作了一个测试脚本:

import time
import numpy as np

import square_mean_py
import square_mean_cy

N = 500

arr = np.random.randint(15, size=(N, N))
r = 8

# Timing

t = time.time()
res_py = square_mean_py.filter(arr, r)
print time.time()-t

t = time.time()
res_cy = square_mean_cy.filter(arr, r)
print time.time()-t

打印

9.61458301544
1.44476890564

这是一个大约加速的速度。 7次。我已经看到很多Cython实现了更好的加速,所以我想也许你们中的一些人看到了加速算法的潜在方法?

1 个答案:

答案 0 :(得分:3)

您的Cython脚本存在一些问题:

  1. 您没有向Cython提供一些关键信息,例如范围中使用的x, y, MN的类型。
  2. cdef编辑了两个函数cython_sumclamp,因为您在Python级别不需要它们。
  3. im功能中出现的filter是什么?我假设你的意思是arr
  4. 修复那些我将重写/修改你的Cython脚本:

    from __future__ import division
    import numpy as np
    cimport numpy as np
    from cython cimport boundscheck, wraparound 
    
    DTYPE = np.int
    ctypedef np.int_t DTYPE_t
    
    cdef inline int int_max(int a, int b): return a if a >= b else b
    cdef inline int int_min(int a, int b): return a if a <= b else b
    
    cdef int clamp3(int val, int minval, int maxval):
        return int_max(minval, int_min(val, maxval))
    
    @boundscheck(False)
    cdef int cython_sum2(DTYPE_t[:] y):
        cdef int N = y.shape[0]
        cdef int x = y[0]
        cdef int i
        for i in range(1, N):
            x += y[i]
        return x
    
    @boundscheck(False)
    @wraparound(False)
    def filter3(DTYPE_t[:,::1] arr, int r):
        cdef int M = arr.shape[0]
        cdef int N = arr.shape[1]
    
        cdef np.ndarray[DTYPE_t, ndim=2, mode='c'] \
        new_arr = np.zeros([M, N], dtype=DTYPE)
        cdef int p1, p2, p3, p4, nbr_elements, tmp, x, y
    
        for x in range(M):
            for y in range(N):
                # Corner elements
                p1 = clamp3(x-r, 0, M)
                p2 = clamp3(y-r, 0, N)
                p3 = clamp3(y+r, 0, N-1)
                p4 = clamp3(x+r, 0, M-1)
    
                nbr_elements = (p3-p2-1)*2+(p4-p1-1)*2+4
    
                tmp = 0
    
                # End points
                tmp += arr[p1, p2]
                tmp += arr[p1, p3]
                tmp += arr[p4, p2]
                tmp += arr[p4, p3]
    
                # The rest
                tmp += cython_sum2(arr[p1+1:p4, p2])
                tmp += cython_sum2(arr[p1+1:p4, p3])
                tmp += cython_sum2(arr[p1, p2+1:p3])
                tmp += cython_sum2(arr[p4, p2+1:p3])
    
                new_arr[x, y] = <int>(tmp/nbr_elements)
    
        return new_arr
    

    这是我机器上的时间:

    arr = np.random.randint(15, size=(500, 500))
    
    Original (Python) version: 7.34 s
    Your Cython version: 1.98 s
    New Cython version: 0.0323 s
    

    这比你的Cython脚本加速了近60倍,比原始Python脚本快了200倍。