我很困惑,与Cython相比,NumPy嵌套循环的3D数组是如此之慢。 我写了一些简单的例子。
Python / NumPy版本:
import numpy as np
def my_func(a,b,c):
s=0
for z in xrange(401):
for y in xrange(401):
for x in xrange(401):
if a[z,y,x] == 0 and b[x,y,z] >= 0:
c[z,y,x] = 1
b[z,y,x] = z*y*x
s+=1
return s
a = np.zeros((401,401,401), dtype=np.float32)
b = np.zeros((401,401,401), dtype=np.uint32)
c = np.zeros((401,401,401), dtype=np.uint8)
s = my_func(a,b,c)
Cythonized版本:
cimport numpy as np
cimport cython
@cython.boundscheck(False)
@cython.wraparound(False)
def my_func(np.float32_t[:,:,::1] a, np.uint32_t[:,:,::1] b, np.uint8_t[:,:,::1] c):
cdef np.uint16_t z,y,x
cdef np.uint32_t s = 0
for z in range(401):
for y in range(401):
for x in range(401):
if a[z,y,x] == 0 and b[x,y,z] >= 0:
c[z,y,x] = 1
b[z,y,x] = z*y*x
s = s+1
return s
my_func()
的Cythonized版本大约运行。快6500倍。仅使用if语句和数组访问的简单函数甚至可以快10000倍。 Python版my_func()
需要500.651秒。完成。迭代相对较小的3D阵列这么慢还是我在代码中犯了一些错误?
Cython版本0.21.1,Python 2.7.5,GCC 4.8.1,Xubuntu 13.10。
答案 0 :(得分:5)
Python是一种解释型语言。编译到机器代码的一个好处是可以获得巨大的加速,尤其是嵌套循环等。
我不知道你的期望是什么,但所有解释的语言在你想要做的事情上会非常缓慢(JIT compiling在某种程度上可能会有所帮助)。
从Numpy(或MATLAB或类似的东西)中获得良好性能的技巧是完全避免循环,而是尝试将代码重构为大型矩阵上的一些操作。这样,循环将在(经过大量优化的)机器代码库中进行,而不是在Python代码中进行。
答案 1 :(得分:4)
正如Krumelur所提到的,python循环肯定很慢。但是,您可以使用numpy来获得优势。整个阵列上的操作非常快,尽管有时你需要一点点聪明才智。
例如,在你的代码中,因为你的循环在修改之后永远不会读取b
中的值(我想?我的脑袋现在有点模糊,所以你绝对想要经历这个,以下应该是等价的:
# Precalculate a matrix of x*y*z
tmp = np.indices(a.shape)
prod = (tmp[:,:,:,0] * tmp[:,:,:,1] * tmp[:,:,:,2]).T
# Use array-wide logical operations to compute c using a and the transpose of b
condition = np.logical_and(a == 0, b.T >= 0)
# Use condition to alter b and c only where condition is true
b[condition] = prod[condition]
c[condition] = 1
s = condition.sum()
即使在条件为假的情况下,这也会计算x*y*z
。如果事实证明使用了大量时间,你可能会避免这种情况,但它可能不是一个重要因素。
答案 2 :(得分:2)
对于在python中使用numpy数组的循环很慢,您应该尽可能使用向量计算。如果算法需要为数组中的每个元素循环,这里有一些加速提示。
a[z,y,x]
是一个numpy标量值,使用numpy标量值的计算非常慢:
x = 3.0
%timeit x > 0
x = np.float64(3.0)
%timeit x > 0
我的电脑上的输出有numpy 1.8.2,windows 7:
10000000 loops, best of 3: 64.3 ns per loop
1000000 loops, best of 3: 657 ns per loop
您可以使用item()
方法直接获取python值:
if a.item(z, y, x) == 0 and b.item(x, y, z) >= 0:
...
它可以将for循环加速大约8倍。