为什么Numba无法改善此递归功能

时间:2020-06-14 12:17:07

标签: python arrays numpy binary-search numba

我有一个具有简单结构的true / false值数组:

# the real array has hundreds of thousands of items
positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=np.bool)

我想遍历此数组并输出发生更改的位置(true变为false或相反)。为此,我提出了两种不同的方法:

  • 递归二进制搜索(查看所有值是否相同,如果不相同,则一分为二,然后递归)
  • 纯粹的迭代搜索(遍历所有元素并与上一个/下一个进行比较)

两个版本都能提供我想要的结果,但是Numba对一个的影响大于另一个。使用300k值的虚拟数组,以下是性能结果:

具有30万个元素的数组的性能结果

  • 纯Python二进制搜索的运行时间为11毫秒
  • 纯Python迭代搜索的运行时间为1.1秒(比二进制搜索慢100倍)
  • Numba二进制搜索在5毫秒内运行(比同等的Python快2倍)
  • Numba迭代搜索的运行时间为900 µs((比纯Python快1200倍)

结果是,使用Numba时,binary_search的速度比iterative_search慢5倍,而理论上应该快100倍(如果正确加速,则预计运行时间为9 µs)。

如何使Numba加速二分搜索和加速迭代搜索?

这两种方法的代码(以及示例position数组)都可以在以下公开要点上找到:https://gist.github.com/JivanRoquet/d58989aa0a4598e060ec2c705b9f3d8f

注意:Numba不在对象模式下运行binary_search(),因为在提及nopython=True时,它不会抱怨并很高兴地编译该函数。

3 个答案:

答案 0 :(得分:3)

主要问题是您没有执行苹果到苹果的比较。 您提供的不是同一算法的迭代版本和递归版本。 您正在提出两种根本不同的算法,它们恰好是递归/迭代的。

尤其是您在递归方法中使用了更多的NumPy内置函数,因此难怪这两种方法之间存在如此惊人的差异。当您避免使用NumPy内置组件时,Numba JITting更有效也就不足为奇了。 最终,递归算法似乎效率较低,因为np.all()np.any()调用中有一些 hidden 嵌套循环,因此避免了迭代方法,因此即使您为了用Numba更有效地用纯Python编写所有代码,递归方法会更慢。

通常,迭代方法比递归等效更快,因为它们避免了函数调用开销(与纯Python相比,JIT加速函数的开销最小)。 因此,我建议不要尝试以递归形式重写算法,以免发现速度较慢。


编辑

在一个简单的np.diff()就可以解决问题的前提下,Numba仍然会非常有益:

import numpy as np
import numba as nb


@nb.jit
def diff(arr):
    n = arr.size
    result = np.empty(n - 1, dtype=arr.dtype)
    for i in range(n - 1):
        result[i] = arr[i + 1] ^ arr[i]
    return result


positions = np.random.randint(0, 2, size=300_000, dtype=bool)
print(np.allclose(np.diff(positions), diff(positions)))
# True


%timeit np.diff(positions)
# 1000 loops, best of 3: 603 µs per loop
%timeit diff(positions)
# 10000 loops, best of 3: 43.3 µs per loop

使用Numba的方法快大约13倍(在此测试中,里程可能会有所不同)。

答案 1 :(得分:3)

您可以使用np.diff找到值变化的位置,无需运行更复杂的算法或使用numba

positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=np.bool)
dpos = np.diff(positions)
# array([ True, False, False,  True, False, False, False,  True, False, False])

这有效,因为False - True == -1np.bool(-1) == True

在我的电池供电(=由于节能模式而节流)和几岁的笔记本电脑上,它的性能都很好:

In [52]: positions = np.random.randint(0, 2, size=300_000, dtype=bool)          

In [53]: %timeit np.diff(positions)                                             
633 µs ± 4.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

我想在numba中编写自己的差异应该会产生相似的性能。

编辑:最后一条语句是错误的,我使用numba实现了一个简单的diff函数,它比numpy快了10倍(但显然功能也少得多) ,但足以完成此任务):

@numba.njit 
def ndiff(x): 
    s = x.size - 1 
    r = np.empty(s, dtype=x.dtype) 
    for i in range(s): 
        r[i] = x[i+1] - x[i] 
    return r

In [68]: np.all(ndiff(positions) == np.diff(positions))                            
Out[68]: True

In [69]: %timeit ndiff(positions)                                               
46 µs ± 138 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

答案 2 :(得分:1)

要点是,只有使用Python机制的逻辑部分可以被加速-通过用等效的C逻辑替换它,可以消除Python运行时的大部分复杂性(和灵活性)(我想这就是Numba) )。

NumPy操作中所有繁重的工作已经在C语言中实现并且非常简单(因为NumPy数组是存储常规C类型的连续内存块),因此Numba只能剥离与Python机器接口的部分。

您的“二进制搜索”算法可以做更多的工作,并且在使用NumPy的矢量运算时会大量使用它,因此可以通过这种方式来加速它的工作。