快速查找数组数组中的数组索引

时间:2013-07-23 13:10:32

标签: python arrays search numpy multidimensional-array

假设我有一个长度为4的numpy数组:

In [41]: arr
Out[41]:
array([[  1,  15,   0,   0],
       [ 30,  10,   0,   0],
       [ 30,  20,   0,   0],
       ...,
       [104, 139, 146,  75],
       [  9,  11, 146,  74],
       [  9, 138, 146,  75]], dtype=uint8)

我想知道:

  1. arr是否包括[1, 2, 3, 4]
  2. 如果确实[1, 2, 3, 4]中的arr索引是什么?
  3. 我想尽可能快地找到它。

    假设arr包含8550420个元素。我用timeit检查了几种方法:

    1. 仅用于检查而不获取索引:any(all([1, 2, 3, 4] == elt) for elt in arr)。我机器上的10次运行平均耗时15.5秒
    2. for - 基于解决方案:

      for i,e in enumerate(arr): if list(e) == [1, 2, 3, 4]: break

      平均花费约5.7秒

    3. 是否存在一些更快的解决方案,例如基于numpy?

2 个答案:

答案 0 :(得分:6)

这是Jaime's idea,我很喜欢它:

import numpy as np

def asvoid(arr):
    """View the array as dtype np.void (bytes)
    This collapses ND-arrays to 1D-arrays, so you can perform 1D operations on them.
    https://stackoverflow.com/a/16216866/190597 (Jaime)"""    
    arr = np.ascontiguousarray(arr)
    return arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1])))

def find_index(arr, x):
    arr_as1d = asvoid(arr)
    x = asvoid(x)
    return np.nonzero(arr_as1d == x)[0]


arr = np.array([[  1,  15,   0,   0],
                [ 30,  10,   0,   0],
                [ 30,  20,   0,   0],
                [1, 2, 3, 4],
                [104, 139, 146,  75],
                [  9,  11, 146,  74],
                [  9, 138, 146,  75]], dtype='uint8')

arr = np.tile(arr,(1221488,1))
x = np.array([1,2,3,4], dtype='uint8')

print(find_index(arr, x))

产量

[      3      10      17 ..., 8550398 8550405 8550412]

我们的想法是将数组的每个视为一个字符串。例如,

In [15]: x
Out[15]: 
array([^A^B^C^D], 
      dtype='|V4')

字符串看起来像垃圾,但它们实际上只是每行被视为字节的基础数据。然后,您可以比较arr_as1d == x以找到等于x


There is another way这样做:

def find_index2(arr, x):
    return np.where((arr == x).all(axis=1))[0]

但结果并不那么快:

In [34]: %timeit find_index(arr, x)
1 loops, best of 3: 209 ms per loop

In [35]: %timeit find_index2(arr, x)
1 loops, best of 3: 370 ms per loop

答案 1 :(得分:0)

如果你多次执行搜索并且不介意使用额外的内存,你可以从你的数组创建set(我在这里使用list,但它几乎是相同的代码):

>>> elem = [1, 2, 3, 4]    
>>> elements = [[  1,  15,   0,   0], [ 30,  10,   0,   0], [1, 2, 3, 4]]
>>> index = set([tuple(x) for x in elements])
>>> True if tuple(elem) in index else False
True