排序数组更快的替代np.where

时间:2020-01-10 12:45:13

标签: python arrays numpy where-clause

给出一个沿每行排序的大型数组a,是否有比numpy的np.where更快的替代方法来找到索引min_v <= a <= max_v?我可以想象,利用数组的排序特性应该可以加快处理速度。

下面是使用np.where在大型数组中查找给定索引的设置示例。

import numpy as np

# Initialise an example of an array in which to search
r, c = int(1e2), int(1e6)
a = np.arange(r*c).reshape(r, c)

# Set up search limits
min_v = (r*c/2)-10
max_v = (r*c/2)+10

# Find indices of occurrences
idx = np.where(((a >= min_v) & (a <= max_v)))

2 个答案:

答案 0 :(得分:2)

您可以使用np.searchsorted

import numpy as np

r, c = 10, 100
a = np.arange(r*c).reshape(r, c)

min_v = ((r * c) // 2) - 10
max_v = ((r * c) // 2) + 10

# Old method
idx = np.where(((a >= min_v) & (a <= max_v)))

# With searchsorted
i1 = np.searchsorted(a.ravel(), min_v, 'left')
i2 = np.searchsorted(a.ravel(), max_v, 'right')
idx2 = np.unravel_index(np.arange(i1, i2), a.shape)
print((idx[0] == idx2[0]).all() and (idx[1] == idx2[1]).all())
# True

答案 1 :(得分:1)

在原始示例中,我使用np.searchsorted并使用了最新版本的NumPy 1.12.1(无法告知较新的版本)中的1亿个数字,它并没有比{ {1}}:

np.where

但是,尽管>>> import timeit >>> timeit.timeit('np.where(((a >= min_v) & (a <= max_v)))', number=10, globals=globals()) 6.685825735330582 >>> timeit.timeit('np.searchsorted(a.ravel(), [min_v, max_v])', number=10, globals=globals()) 5.304438766092062 的NumPy文档说此函数使用与内置python searchsortedbisect.bisect_left函数相同的算法,快很多

bisect.bisect_right

因此,我将使用它:

>>> import bisect
>>> timeit.timeit('bisect.bisect_left(a.base, min_v), bisect.bisect_right(a.base, max_v)', number=10, globals=globals())
0.002058468759059906