我正在尝试找到一个二维数组出现在3d numpy ndarray中的行。这是我的意思的一个例子。得到:
arr = [[[0, 3], [3, 0]],
[[0, 0], [0, 0]],
[[3, 3], [3, 3]],
[[0, 3], [3, 0]]]
我想找到所有出现的:
[[0, 3], [3, 0]]
我想要的结果是:
[0, 3]
我尝试使用argwhere
,但不幸的是,这让我无处可去。有什么想法吗?
答案 0 :(得分:5)
尝试
np.argwhere(np.all(arr==[[0,3], [3,0]], axis=(1,2)))
工作原理:
arr == [[0,3], [3,0]]
返回
array([[[ True, True],
[ True, True]],
[[ True, False],
[False, True]],
[[False, True],
[ True, False]],
[[ True, True],
[ True, True]]], dtype=bool)
这是一个三维数组,其中最内轴为2.此轴的值为:
[True, True]
[True, True]
[True, False]
[False, True]
[False, True]
[True, False]
[True, True]
[True, True]
现在使用np.all(arr==[[0,3], [3,0]], axis=2)
检查行上的两个元素是否为True
,其形状是否会从(4,2,2)缩小为(4,2)。像这样:
array([[ True, True],
[False, False],
[False, False],
[ True, True]], dtype=bool)
你需要再减少一步,因为你希望它们都是相同的([0, 3]
和[3, 0]
。你可以通过减少结果(现在最里面的轴是1):
np.all(np.all(test, axis = 2), axis=1)
或者你也可以通过给轴参数的元组一步一步地做同样的事情(第一个最里面,然后再高一步)来做到这一点。结果将是:
array([ True, False, False, True], dtype=bool)
答案 1 :(得分:2)
numpy_indexed包中的'contains'函数(免责声明:我是它的作者)可用于进行此类查询。它实现了类似于Saullo提供的解决方案。
import numpy_indexed as npi
test = [[[0, 3], [3, 0]]]
# check which elements of arr are present in test (checked along axis=0 by default)
flags = npi.contains(test, arr)
# if you want the indexes:
idx = np.flatnonzero(flags)
答案 2 :(得分:1)
您可以在定义新数据类型后使用np.in1d
,该数据类型将包含arr
中每行的内存大小。要定义此类数据类型:
mydtype = np.dtype((np.void, arr.dtype.itemsize*arr.shape[1]*arr.shape[2]))
然后您必须将arr
转换为1-D数组,其中每行将包含arr.shape[1]*arr.shape[2]
个元素:
aView = np.ascontiguousarray(arr).flatten().view(mydtype)
您现在已准备好寻找也必须转换为[[0, 3], [3, 0]]
的二维数组模式dtype
:
bView = np.array([[0, 3], [3, 0]]).flatten().view(mydtype)
您现在可以查看bView
中aView
的发生次数:
np.in1d(aView, bView)
#array([ True, False, False, True], dtype=bool)
例如,使用np.where
可以轻松地将此蒙版转换为索引。
以下函数用于实现此方法:
def check2din3d(b, a):
"""
Return where `b` (2D array) appears in `a` (3D array) along `axis=0`
"""
mydtype = np.dtype((np.void, a.dtype.itemsize*a.shape[1]*a.shape[2]))
aView = np.ascontiguousarray(a).flatten().view(mydtype)
bView = np.ascontiguousarray(b).flatten().view(mydtype)
return np.in1d(aView, bView)
考虑到@ayhan评论的更新时间表明,这种方法可以更快地在np.argwhere,但是差异并不显着,对于像下面这样的大型数组,@ ayhan的方法要快得多:
arrLarge = np.concatenate([arr]*10000000)
arrLarge = np.concatenate([arrLarge]*10, axis=2)
pattern = np.ascontiguousarray([[0,3]*10, [3,0]*10])
%timeit np.argwhere(np.all(arrLarger==pattern, axis=(1,2)))
#1 loops, best of 3: 2.99 s per loop
%timeit check2din3d(pattern, arrLarger)
#1 loops, best of 3: 4.65 s per loop