NumPy vs Cython - 嵌套循环这么慢?

时间:2014-10-28 14:35:13

标签: python numpy

我很困惑,与Cython相比,NumPy嵌套循环的3D数组是如此之慢。 我写了一些简单的例子。

Python / NumPy版本:

import numpy as np

def my_func(a,b,c):
    s=0
        for z in xrange(401):
        for y in xrange(401):
            for x in xrange(401):
                if a[z,y,x] == 0 and b[x,y,z] >= 0:
                    c[z,y,x] = 1
                    b[z,y,x] = z*y*x
                    s+=1
    return s

a = np.zeros((401,401,401), dtype=np.float32)
b = np.zeros((401,401,401), dtype=np.uint32)
c = np.zeros((401,401,401), dtype=np.uint8)

s = my_func(a,b,c)

Cythonized版本:

cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def my_func(np.float32_t[:,:,::1] a, np.uint32_t[:,:,::1] b, np.uint8_t[:,:,::1] c):
    cdef np.uint16_t z,y,x
    cdef np.uint32_t s = 0

    for z in range(401):
        for y in range(401):
            for x in range(401):
                if a[z,y,x] == 0 and b[x,y,z] >= 0:
                    c[z,y,x] = 1
                    b[z,y,x] = z*y*x
                    s = s+1
    return s

my_func()的Cythonized版本大约运行。快6500倍。仅使用if语句和数组访问的简单函数甚至可以快10000倍。 Python版my_func()需要500.651秒。完成。迭代相对较小的3D阵列这么慢还是我在代码中犯了一些错误?

Cython版本0.21.1,Python 2.7.5,GCC 4.8.1,Xubuntu 13.10。

3 个答案:

答案 0 :(得分:5)

Python是一种解释型语言。编译到机器代码的一个好处是可以获得巨大的加速,尤其是嵌套循环等。

我不知道你的期望是什么,但所有解释的语言在你想要做的事情上会非常缓慢(JIT compiling在某种程度上可能会有所帮助)。

从Numpy(或MATLAB或类似的东西)中获得良好性能的技巧是完全避免循环,而是尝试将代码重构为大型矩阵上的一些操作。这样,循环将在(经过大量优化的)机器代码库中进行,而不是在Python代码中进行。

答案 1 :(得分:4)

正如Krumelur所提到的,python循环肯定很慢。但是,您可以使用numpy来获得优势。整个阵列上的操作非常快,尽管有时你需要一点点聪明才智。

例如,在你的代码中,因为你的循环在修改之后永远不会读取b中的值(我想?我的脑袋现在有点模糊,所以你绝对想要经历这个,以下应该是等价的:

# Precalculate a matrix of x*y*z
tmp = np.indices(a.shape)
prod = (tmp[:,:,:,0] * tmp[:,:,:,1] * tmp[:,:,:,2]).T

# Use array-wide logical operations to compute c using a and the transpose of b
condition = np.logical_and(a == 0, b.T >= 0)

# Use condition to alter b and c only where condition is true
b[condition] = prod[condition]
c[condition] = 1

s = condition.sum()

即使在条件为假的情况下,这也会计算x*y*z。如果事实证明使用了大量时间,你可能会避免这种情况,但它可能不是一个重要因素。

答案 2 :(得分:2)

对于在python中使用numpy数组的循环很慢,您应该尽可能使用向量计算。如果算法需要为数组中的每个元素循环,这里有一些加速提示。

a[z,y,x]是一个numpy标量值,使用numpy标量值的计算非常慢:

x = 3.0
%timeit x > 0

x = np.float64(3.0)
%timeit x > 0

我的电脑上的输出有numpy 1.8.2,windows 7:

10000000 loops, best of 3: 64.3 ns per loop
1000000 loops, best of 3: 657 ns per loop

您可以使用item()方法直接获取python值:

if a.item(z, y, x) == 0 and b.item(x, y, z) >= 0:
    ...

它可以将for循环加速大约8倍。