如何检查较大形状的numpy数组中是否存在给定的numpy数组?

时间:2018-11-19 07:51:24

标签: numpy

我想我的问题标题可能不太清楚。

我有一个很小的数组,例如a = ([[0,0,0],[0,0,1],[0,1,1]])。然后,我得到了更大维度的更大数组,例如b = ([[[2,2,2],[2,0,1],[2,1,1]],[[0,0,0],[3,3,1],[3,1,1]],[...]])

我想检查是否可以在b中找到a的元素之一。在这种情况下,我发现[0,0,0]的第一个元素确实在b中,然后我想在b中检索相应的索引。

我想避免循环,因为从我对numpy数组了解的很少,它们并不意味着以经典方式进行迭代。换句话说,我需要非常快,因为我的实际数组很大。

有什么主意吗? 非常感谢!

Arnaud。

1 个答案:

答案 0 :(得分:1)

我不知道直接的方法,但是我这里有一个解决该问题的函数:

import numpy as np

def find_indices(val, arr):
    # first take a mean at the lowest level of each array,
    # then compare these to eliminate the majority of entries
    mb = np.mean(arr, axis=2); ma = np.mean(val)
    Y = np.argwhere(mb==ma)
    indices = []
    # Then run a quick loop on the remaining elements to
    # eliminate arrays that don't match the order
    for i in range(len(Y)):
        idx = (Y[i,0],Y[i,1])
        if np.array_equal(val, arr[idx]):
            indices.append(idx)

    return indices

# Sample arrays
a = np.array([[0,0,0],[0,0,1],[0,1,1]])
b = np.array([ [[6,5,4],[0,0,1],[2,3,3]],   \
               [[2,5,4],[6,5,4],[0,0,0]],   \
               [[2,0,2],[3,5,4],[5,4,6]],   \
               [[6,5,4],[0,0,0],[2,5,3]]  ])



print(find_indices(a[0], b))
# [(1, 2), (3, 1)]
print(find_indices(a[1], b))
# [(0, 1)]

这个想法是使用每个数组的平均值并将其与输入的平均值进行比较。 np.argwhere()是这里的关键。这样,您就可以删除大多数不需要的匹配项,但是我确实需要对其余部分使用循环,以避免未排序的匹配项(这不应该占用太多内存)。您可能想进一步对其进行自定义,但是希望对您有所帮助。