如何在3d numpy数组中找到2d数组的行

时间:2016-04-03 03:11:25

标签: python arrays numpy

我正在尝试找到一个二维数组出现在3d numpy ndarray中的行。这是我的意思的一个例子。得到:

arr = [[[0, 3], [3, 0]],
       [[0, 0], [0, 0]],
       [[3, 3], [3, 3]],
       [[0, 3], [3, 0]]]

我想找到所有出现的:

[[0, 3], [3, 0]]

我想要的结果是:

[0, 3]

我尝试使用argwhere,但不幸的是,这让我无处可去。有什么想法吗?

3 个答案:

答案 0 :(得分:5)

尝试

np.argwhere(np.all(arr==[[0,3], [3,0]], axis=(1,2)))

工作原理:

arr == [[0,3], [3,0]]返回

array([[[ True,  True],
        [ True,  True]],

       [[ True, False],
        [False,  True]],

       [[False,  True],
        [ True, False]],

       [[ True,  True],
        [ True,  True]]], dtype=bool)

这是一个三维数组,其中最内轴为2.此轴的值为:

[True, True]
[True, True]
[True, False]
[False, True]
[False, True]
[True, False]
[True, True]
[True, True]

现在使用np.all(arr==[[0,3], [3,0]], axis=2)检查行上的两个元素是否为True,其形状是否会从(4,2,2)缩小为(4,2)。像这样:

array([[ True,  True],
       [False, False],
       [False, False],
       [ True,  True]], dtype=bool)

你需要再减少一步,因为你希望它们都是相同的([0, 3][3, 0]。你可以通过减少结果(现在最里面的轴是1):

np.all(np.all(test, axis = 2), axis=1)

或者你也可以通过给轴参数的元组一步一步地做同样的事情(第一个最里面,然后再高一步)来做到这一点。结果将是:

array([ True, False, False,  True], dtype=bool)

答案 1 :(得分:2)

numpy_indexed包中的'contains'函数(免责声明:我是它的作者)可用于进行此类查询。它实现了类似于Saullo提供的解决方案。

import numpy_indexed as npi
test = [[[0, 3], [3, 0]]]
# check which elements of arr are present in test (checked along axis=0 by default)
flags = npi.contains(test, arr)
# if you want the indexes:
idx = np.flatnonzero(flags)

答案 2 :(得分:1)

您可以在定义新数据类型后使用np.in1d,该数据类型将包含arr中每行的内存大小。要定义此类数据类型:

mydtype = np.dtype((np.void, arr.dtype.itemsize*arr.shape[1]*arr.shape[2]))

然后您必须将arr转换为1-D数组,其中每行将包含arr.shape[1]*arr.shape[2]个元素:

aView = np.ascontiguousarray(arr).flatten().view(mydtype)

您现在已准备好寻找也必须转换为[[0, 3], [3, 0]]的二维数组模式dtype

bView = np.array([[0, 3], [3, 0]]).flatten().view(mydtype)

您现在可以查看bViewaView的发生次数:

np.in1d(aView, bView)
#array([ True, False, False,  True], dtype=bool)

例如,使用np.where可以轻松地将此蒙版转换为索引。

计时(更新)

以下函数用于实现此方法:

def check2din3d(b, a):
        """
        Return where `b` (2D array) appears in `a` (3D array) along `axis=0`
        """
        mydtype = np.dtype((np.void, a.dtype.itemsize*a.shape[1]*a.shape[2]))
        aView = np.ascontiguousarray(a).flatten().view(mydtype)
        bView = np.ascontiguousarray(b).flatten().view(mydtype)
        return np.in1d(aView, bView)

考虑到@ayhan评论的更新时间表明,这种方法可以更快地在np.argwhere,但是差异并不显着,对于像下面这样的大型数组,@ ayhan的方法要快得多:

arrLarge = np.concatenate([arr]*10000000)
arrLarge = np.concatenate([arrLarge]*10, axis=2)

pattern = np.ascontiguousarray([[0,3]*10, [3,0]*10])

%timeit np.argwhere(np.all(arrLarger==pattern, axis=(1,2)))
#1 loops, best of 3: 2.99 s per loop

%timeit check2din3d(pattern, arrLarger)
#1 loops, best of 3: 4.65 s per loop