我有一个形状为(1000, 12, 30)
的3d数组,我有一个2d数组的形状列表(12, 30)
,我想要检查的是3d数组中是否存在这些2d数组。 Python中有一种简单的方法可以做到这一点吗?我尝试过关键字in
,但它不起作用。
答案 0 :(得分:3)
numpy
中有一种方式,您可以使用np.all
a=np.random.rand(3,1,2)
b=a[1][0]
np.all(np.all(a==b,1),1)
Out[612]: array([False, True, False])
来自bnaecker的解决方案
np.all(a == b, axis=(1,2))
如果只想检查退出
np.any(np.all(a == b, axis=(1,2)))
答案 1 :(得分:3)
这是一种快速方法(以前used by @DanielF以及as @jaime和其他人,毫无疑问)使用技巧从短路中获益:将模板大小的块转换为dtype void
的单个元素。当比较两个这样的块时,在第一个差异之后numpy停止,产生巨大的速度优势。
>>> def in_(data, template):
... dv = data.reshape(data.shape[0], -1).view(f'V{data.dtype.itemsize*np.prod(data.shape[1:])}').ravel()
... tv = template.ravel().view(f'V{template.dtype.itemsize*template.size}').reshape(())
... return (dv==tv).any()
示例:
>>> a = np.random.randint(0, 100, (1000, 12, 30))
>>> check = a[np.random.randint(0, 1000, (10,))]
>>> check += np.random.random(check.shape) < 0.001
>>>
>>> [in_(a, c) for c in check]
[True, True, True, False, False, True, True, True, True, False]
# compare to other method
>>> (a==check[:, None]).all((-1,-2)).any(-1)
array([ True, True, True, False, False, True, True, True, True,
False])
与“直接”numpy方法相同,但速度提高了近20倍:
>>> from timeit import timeit
>>> kwds = dict(globals=globals(), number=100)
>>>
>>> timeit("(a==check[:, None]).all((-1,-2)).any(-1)", **kwds)
0.4793281531892717
>>> timeit("[in_(a, c) for c in check]", **kwds)
0.026218891143798828
答案 2 :(得分:2)
鉴于
a = np.arange(12).reshape(3, 2, 2)
lst = [
np.arange(4).reshape(2, 2),
np.arange(4, 8).reshape(2, 2)
]
print(a, *lst, sep='\n{}\n'.format('-' * 20))
[[[ 0 1]
[ 2 3]]
[[ 4 5]
[ 6 7]]
[[ 8 9]
[10 11]]]
--------------------
[[0 1]
[2 3]]
--------------------
[[4 5]
[6 7]]
请注意,lst
是根据OP的数组列表。我将在下面制作一个3d数组b
。
使用广播。使用广播规则。我希望a
的维度为(1, 3, 2, 2)
和b
为(2, 1, 2, 2)
。
b = np.array(lst)
x, *y = b.shape
c = np.equal(
a.reshape(1, *a.shape),
np.array(lst).reshape(x, 1, *y)
)
我将使用all
生成(2, 3)
个真值数组,并np.where
找出a
和b
个子数组中的哪一个实际上是平等的。
i, j = np.where(c.all((-2, -1)))
这只是我们实现了目标的验证。我们应该观察到,对于每个配对的i
和j
值,子数组实际上是相同的。
for t in zip(i, j):
print(a[t[0]], b[t[1]], sep='\n\n')
print('------')
[[0 1]
[2 3]]
[[0 1]
[2 3]]
------
[[4 5]
[6 7]]
[[4 5]
[6 7]]
------
in
然而,要完成OP使用in
a_ = a.tolist()
list(filter(lambda x: x.tolist() in a_, lst))
[array([[0, 1],
[2, 3]]), array([[4, 5],
[6, 7]])]