在Python中快速检查范围

时间:2014-05-13 18:54:30

标签: python optimization micro-optimization

我有[(1, 1000), (5000, 5678), ... ]形式的很多范围。我正试图找出检查数字是否在任何范围内的最快方法。范围由longs组成,并且太大而无法保留所有数字的set

最简单的解决方案是:

ranges = [(1,5), (10,20), (40,50)]  # The real code has a few dozen ranges
nums = range(1000000)  
%timeit [n for n in nums if any([r[0] <= n <= r[1] for r in ranges])]
# 1 loops, best of 3: 5.31 s per loop

榕树有点快:

import banyan
banyan_ranges = banyan.SortedSet(updator=banyan.OverlappingIntervalsUpdator)
for r in ranges:
    banyan_ranges.add(r)
%timeit [n for n in nums if len(banyan_ranges.overlap_point(n))>0]
# 1 loops, best of 3: 452 ms per loop

虽然只有几十个范围,但是对这些范围进行了数百万次检查。这些检查的最快方法是什么?

(注意:这个问题类似于Python: efficiently check if integer is within *many* ranges,但没有相同的Django相关限制,只关注速度)

3 个答案:

答案 0 :(得分:8)

要尝试的事情:

  1. 预处理您的范围,使它们不重叠,并将其表示为半开间隔。
  2. 使用bisect模块进行搜索。 (不要手动实现自己的二分搜索!)请注意,预处理为1时,您需要知道的是bisect调用的结果是偶数还是奇数。 / LI>
  3. 如果批量查询是一个选项,请考虑将输入分组并使用numpy.searchsorted
  4. 一些代码和时间。首先是设置(这里使用IPython 2.1和Python 3.4):

    In [1]: ranges = [(1, 5), (10, 20), (40, 50)]
    
    In [2]: nums = list(range(1000000))  # force a list to remove generator overhead
    

    我机器上原始方法的计时(但是使用生成器表达而不是列表推导):

    In [3]: %timeit [n for n in nums if any(r[0] <= n <= r[1] for r in ranges)]
    1 loops, best of 3: 922 ms per loop
    

    现在我们将范围重新修改为边界点列表; 偶数索引处的每个边界点是其中一个范围的入口点,而奇数索引处的每个边界点都是一个出口点。请注意转换为半开的时间间隔,并且我已将所有数字放入单个列表中。

    In [4]: boundaries = [1, 6, 10, 21, 40, 51]
    

    使用此功能,bisect.bisect可以轻松获得与以前相同的结果,但速度更快。

    In [5]: from bisect import bisect
    
    In [6]: %timeit [n for n in nums if bisect(boundaries, n) % 2]
    1 loops, best of 3: 298 ms per loop
    

    最后,根据上下文,您可以使用NumPy中的searchsorted函数。这与bisect.bisect类似,但同时对整个值集合进行操作。例如:

    In [7]: import numpy
    
    In [8]: numpy.where(numpy.searchsorted(boundaries, nums, side="right") % 2)[0]
    Out[8]: 
    array([ 1,  2,  3,  4,  5, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 40,
           41, 42, 43, 44, 45, 46, 47, 48, 49, 50])
    

    乍一看,%timeit的结果相当令人失望。

    In [9]: %timeit numpy.where(numpy.searchsorted(boundaries, nums, side="right") % 2)[0]
    10 loops, best of 3: 159 ms per loop
    

    然而,事实证明,大部分性能成本是将输入转换为searchsorted从Python列表到NumPy数组。让我们将两个列表预转换为数组,然后重试:

    In [10]: boundaries = numpy.array(boundaries)
    
    In [11]: nums = numpy.array(nums)
    
    In [12]: %timeit numpy.where(numpy.searchsorted(boundaries, nums, side="right") % 2)[0]
    10 loops, best of 3: 24.6 ms per loop
    
    到目前为止,

    很多比其他任何事情都要快。但是,这有点作弊:我们当然可以预处理boundaries将其转换为数组,但如果要测试的值不是以数组形式自然生成的,那么转换成本将需要被考虑在内。另一方面,它表明搜索本身的成本可以降低到足够小的值,使其不再可能成为运行时间的主导因素。

    这是这些方面的另一种选择。它再次使用NumPy,但每个值都进行直接的非延迟线性搜索。 (请原谅无序IPython提示:我稍后添加了这个提示。: - )

    In [29]: numpy.where(numpy.logical_xor.reduce(numpy.greater_equal.outer(boundaries, nums), axis=0))
    Out[29]: 
    (array([ 2,  3,  4,  5,  6, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 41,
            42, 43, 44, 45, 46, 47, 48, 49, 50, 51]),)
    
    In [30]: %timeit numpy.where(numpy.logical_xor.reduce(numpy.greater_equal.outer(boundaries, nums), axis=0))
    10 loops, best of 3: 16.7 ms per loop
    

    对于这些特定的测试数据,这比searchsorted快,但时间将以范围数量线性增长,而对于searchsorted,它应该根据范围数量的对数增长。请注意,它还使用与len(boundaries) * len(nums)成比例的内存量。这不一定是个问题:如果你发现自己遇到了内存限制,你可以将数组分成更小的大小(一次说10000个元素),而不会失去太多的性能。

    向上移动比例,如果这些都不合适,我接下来尝试Cython和NumPy,编写一个搜索函数(输入声明为int的数组),在{{1上进行简单的线性搜索数组。我尝试了这个,但未能获得比基于boundaries的结果更好的结果。作为参考,这里是我试过的Cython代码;你可能会做得更好:

    bisect.bisect

    时间安排:

    cimport cython
    
    cimport numpy as np
    
    @cython.boundscheck(False)
    @cython.wraparound(False)
    def search(np.ndarray[long, ndim=1] boundaries, long val):
        cdef long j, k, n=len(boundaries)
        for j in range(n):
            if boundaries[j] > val:
               return j & 1
        return 0
    

答案 1 :(得分:2)

@ ArminRigo评论的实现,非常快。时间来自CPython,而不是PyPy:

exec_code = "def in_range(x):\n"
first_if = True
for r in ranges:
   if first_if:
      exec_code += "    if "
      first_if = False
   else:
      exec_code += "    elif "
   exec_code += "%d <= x <= %d: return True\n" % (r[0], r[1])
exec_code += "    return False"
exec(exec_code)

%timeit [n for n in nums if in_range(n)]
# 10 loops, best of 3: 173 ms per loop

答案 2 :(得分:1)

尝试使用二分搜索而不是线性搜索。它应该及时花费“Log(n)”。见下文:

list = []
for num in nums:
    start = 0
    end = len(ranges)-1
    if ranges[start][0] <= num <= ranges[start][1]:
        list.append(num)
    elif ranges[end][0] <= num <= ranges[end][1]:
        list.append(num):
    else:
        while end-start>1:
            mid = int(end+start/2)
            if ranges[mid][0] <= num <= ranges[mid][1]:
                list.append(num)
                break
            elif num < ranges[mid][0]:
                end = mid
            else:
                start = mid