为什么numpy.where比替代品要快得多

时间:2019-03-12 14:14:31

标签: python performance numpy

im试图加快以下代码的速度:

import time
import numpy as np
np.random.seed(10)
b=np.random.rand(10000,1000)
def f(a=1):
    tott=0
    for _ in range(a):
        q=np.array(b)
        t1 = time.time()
        for i in range(len(q)):
            for j in range(len(q[0])):
                if q[i][j]>0.5:
                    q[i][j]=1
                else:
                    q[i][j]=-1
        t2=time.time()
        tott+=t2-t1
    print(tott/a)

如您所见,func主要是关于双循环进行迭代。因此,我尝试使用np.nditernp.vectorizemap代替它。如果给予一定的加速(除了np.nditer以外,可以提高4-5倍),但是! np.where(q>0.5,1,-1)的加速几乎是100倍。 我如何以np.where的速度遍历numpy数组?而且为什么这么快?

1 个答案:

答案 0 :(得分:4)

要回答此问题,您可以使用numba库获得相同的速度(100倍加速度):

from numba import njit

def f(b):
    q = np.zeros_like(b)

    for i in range(b.shape[0]):
        for j in range(b.shape[1]):
            if q[i][j] > 0.5:
                q[i][j] = 1
            else:
                q[i][j] = -1

    return q

@njit
def f_jit(b):
    q = np.zeros_like(b)

    for i in range(b.shape[0]):
        for j in range(b.shape[1]):
            if q[i][j] > 0.5:
                q[i][j] = 1
            else:
                q[i][j] = -1

    return q

比较速度:

普通Python

%timeit f(b)
592 ms ± 5.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Numba(使用LLVM〜C速度实时编译)

%timeit f_jit(b)
5.97 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)