numba慢于numpy.bitwise_和布尔数组

时间:2015-12-28 21:18:11

标签: python numpy numba

我正在尝试使用此代码段中的numba

db = np.array([           # out value for mask = [1, 0, 1]
    # target,  vector     #
      [1,      1, 0, 1],  # 1
      [0,      1, 1, 1],  # 0 (fit to mask but target == 0)
      [0,      0, 1, 0],  # 0
      [1,      1, 0, 1],  # 1
      [0,      1, 1, 0],  # 0
      [1,      0, 0, 0],  # 0
      ])

使用numba @jit()装饰器,这段代码运行得更慢!

  • 没有jit:3.16秒
  • with jit:3.81 sec

只是为了帮助更好地理解此代码的目的:

enum

3 个答案:

答案 0 :(得分:4)

或者,您可以尝试Pythran (免责声明:我是Pythran的开发人员)。

使用单个注释,它会编译以下代码

#pythran export check_mask(bool[][], bool[])

import numpy as np
def check_mask(db, out, mask=[1, 0, 1]):
    for idx, line in enumerate(db):
        target, vector = line[0], line[1:]
        if (mask == np.bitwise_and(mask, vector)).all():
            if target == 1:
                out[idx] = 1
    return out

致电pythran check_call.py

根据timeit,生成的本机模块运行得非常快:

python -m timeit -s 'n=1e4; import numpy as np; db  = np.array(np.random.randint(2, size=(n, 4)), dtype=bool); out = np.zeros(int(n), dtype=bool); from eq import check_mask' 'check_mask(db, out)'

告诉我CPython版本在136ms中运行,而Pythran编译版本在450us中运行。

答案 1 :(得分:3)

Numba有jit的两种编译模式:nopython模式和对象模式。 Nopython模式(默认)仅支持一组有限的Python和Numpy功能,请参阅the docs for your version。如果jitted函数包含不受支持的代码,Numba必须回退到对象模式,这要慢很多。

我不确定objcet模式是否应该提供与纯Python相比的加速,但是你总是想要使用nopython模式。要确保使用nopython模式,请指定nopython=True并坚持使用非常基本的代码(经验法则:写出所有循环并仅使用标量和Numpy数组):

@jit(nopython=True)
def check_mask_2(db, out, mask=np.array([1, 0, 1])):
    for idx in range(db.shape[0]):
        if db[idx,0] != 1:
            continue
        check = 1
        for j in range(db.shape[1]):
            if mask[j] and not db[idx,j+1]:
                check = 0
                break
        out[idx] = check
    return out

明确地写出内部循环也有一个好处,就是我们可以在条件失败时立即将其中断。

时序:

%time _ = check_mask(db, out, np.array([1, 0, 1]))
# Wall time: 1.91 s
%time _ = check_mask_2(db, out, np.array([1, 0, 1]))
# Wall time: 310 ms  # slow because of compilation
%time _ = check_mask_2(db, out, np.array([1, 0, 1]))
# Wall time: 3 ms
顺便说一下,这个功能也很容易用Numpy进行矢量化,这样可以提供一个不错的速度:

def check_mask_vectorized(db, mask=[1, 0, 1]):
    check = (db[:,1:] == mask).all(axis=1)
    out = (db[:,0] == 1) & check
    return out

%time _ = check_mask_vectorized(db, [1, 0, 1])
# Wall time: 14 ms

答案 2 :(得分:1)

我建议从内循环中删除对array_equal的numpy调用。 numba不一定非常聪明,可以把它变成一个内联的C;如果它无法取代这个电话,你的功能的主要成本仍然具有可比性,这可以解释你的结果。

虽然numba可以推断出相当数量的numpy结构,但只有C风格的代码可以作用于numpy数组,而这些代码可能依赖于加速。