我想我的问题标题可能不太清楚。
我有一个很小的数组,例如a = ([[0,0,0],[0,0,1],[0,1,1]])
。然后,我得到了更大维度的更大数组,例如b = ([[[2,2,2],[2,0,1],[2,1,1]],[[0,0,0],[3,3,1],[3,1,1]],[...]])
。
我想检查是否可以在b
中找到a的元素之一。在这种情况下,我发现[0,0,0]
的第一个元素确实在b
中,然后我想在b
中检索相应的索引。
我想避免循环,因为从我对numpy数组了解的很少,它们并不意味着以经典方式进行迭代。换句话说,我需要非常快,因为我的实际数组很大。
有什么主意吗? 非常感谢!
Arnaud。
答案 0 :(得分:1)
我不知道直接的方法,但是我这里有一个解决该问题的函数:
import numpy as np
def find_indices(val, arr):
# first take a mean at the lowest level of each array,
# then compare these to eliminate the majority of entries
mb = np.mean(arr, axis=2); ma = np.mean(val)
Y = np.argwhere(mb==ma)
indices = []
# Then run a quick loop on the remaining elements to
# eliminate arrays that don't match the order
for i in range(len(Y)):
idx = (Y[i,0],Y[i,1])
if np.array_equal(val, arr[idx]):
indices.append(idx)
return indices
# Sample arrays
a = np.array([[0,0,0],[0,0,1],[0,1,1]])
b = np.array([ [[6,5,4],[0,0,1],[2,3,3]], \
[[2,5,4],[6,5,4],[0,0,0]], \
[[2,0,2],[3,5,4],[5,4,6]], \
[[6,5,4],[0,0,0],[2,5,3]] ])
print(find_indices(a[0], b))
# [(1, 2), (3, 1)]
print(find_indices(a[1], b))
# [(0, 1)]
这个想法是使用每个数组的平均值并将其与输入的平均值进行比较。 np.argwhere()
是这里的关键。这样,您就可以删除大多数不需要的匹配项,但是我确实需要对其余部分使用循环,以避免未排序的匹配项(这不应该占用太多内存)。您可能想进一步对其进行自定义,但是希望对您有所帮助。