im试图加快以下代码的速度:
import time
import numpy as np
np.random.seed(10)
b=np.random.rand(10000,1000)
def f(a=1):
tott=0
for _ in range(a):
q=np.array(b)
t1 = time.time()
for i in range(len(q)):
for j in range(len(q[0])):
if q[i][j]>0.5:
q[i][j]=1
else:
q[i][j]=-1
t2=time.time()
tott+=t2-t1
print(tott/a)
如您所见,func主要是关于双循环进行迭代。因此,我尝试使用np.nditer
,np.vectorize
和map
代替它。如果给予一定的加速(除了np.nditer
以外,可以提高4-5倍),但是! np.where(q>0.5,1,-1)
的加速几乎是100倍。
我如何以np.where
的速度遍历numpy数组?而且为什么这么快?
答案 0 :(得分:4)
要回答此问题,您可以使用numba
库获得相同的速度(100倍加速度):
from numba import njit
def f(b):
q = np.zeros_like(b)
for i in range(b.shape[0]):
for j in range(b.shape[1]):
if q[i][j] > 0.5:
q[i][j] = 1
else:
q[i][j] = -1
return q
@njit
def f_jit(b):
q = np.zeros_like(b)
for i in range(b.shape[0]):
for j in range(b.shape[1]):
if q[i][j] > 0.5:
q[i][j] = 1
else:
q[i][j] = -1
return q
比较速度:
普通Python
%timeit f(b)
592 ms ± 5.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Numba(使用LLVM〜C速度实时编译)
%timeit f_jit(b)
5.97 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)