假设我有一个长度为4的numpy数组:
In [41]: arr
Out[41]:
array([[ 1, 15, 0, 0],
[ 30, 10, 0, 0],
[ 30, 20, 0, 0],
...,
[104, 139, 146, 75],
[ 9, 11, 146, 74],
[ 9, 138, 146, 75]], dtype=uint8)
我想知道:
arr
是否包括[1, 2, 3, 4]
?[1, 2, 3, 4]
中的arr
索引是什么?我想尽可能快地找到它。
假设arr
包含8550420个元素。我用timeit
检查了几种方法:
any(all([1, 2, 3, 4] == elt) for elt in arr)
。我机器上的10次运行平均耗时15.5秒 for
- 基于解决方案:
for i,e in enumerate(arr):
if list(e) == [1, 2, 3, 4]:
break
平均花费约5.7秒
是否存在一些更快的解决方案,例如基于numpy?
答案 0 :(得分:6)
这是Jaime's idea,我很喜欢它:
import numpy as np
def asvoid(arr):
"""View the array as dtype np.void (bytes)
This collapses ND-arrays to 1D-arrays, so you can perform 1D operations on them.
https://stackoverflow.com/a/16216866/190597 (Jaime)"""
arr = np.ascontiguousarray(arr)
return arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1])))
def find_index(arr, x):
arr_as1d = asvoid(arr)
x = asvoid(x)
return np.nonzero(arr_as1d == x)[0]
arr = np.array([[ 1, 15, 0, 0],
[ 30, 10, 0, 0],
[ 30, 20, 0, 0],
[1, 2, 3, 4],
[104, 139, 146, 75],
[ 9, 11, 146, 74],
[ 9, 138, 146, 75]], dtype='uint8')
arr = np.tile(arr,(1221488,1))
x = np.array([1,2,3,4], dtype='uint8')
print(find_index(arr, x))
产量
[ 3 10 17 ..., 8550398 8550405 8550412]
我们的想法是将数组的每个行视为一个字符串。例如,
In [15]: x
Out[15]:
array([^A^B^C^D],
dtype='|V4')
字符串看起来像垃圾,但它们实际上只是每行被视为字节的基础数据。然后,您可以比较arr_as1d == x
以找到等于x
的行。
def find_index2(arr, x):
return np.where((arr == x).all(axis=1))[0]
但结果并不那么快:
In [34]: %timeit find_index(arr, x)
1 loops, best of 3: 209 ms per loop
In [35]: %timeit find_index2(arr, x)
1 loops, best of 3: 370 ms per loop
答案 1 :(得分:0)
如果你多次执行搜索并且不介意使用额外的内存,你可以从你的数组创建set(我在这里使用list,但它几乎是相同的代码):
>>> elem = [1, 2, 3, 4]
>>> elements = [[ 1, 15, 0, 0], [ 30, 10, 0, 0], [1, 2, 3, 4]]
>>> index = set([tuple(x) for x in elements])
>>> True if tuple(elem) in index else False
True