我有一个很大的ndimensional数组。我想迭代它来检查条件是否在本地满足。下一个片段解释了我的问题。
a = np.random.randint(2, size=(60,80,3,3))
test = np.array([[1,0,0],[0,1,0],[0,0,0]])
for i in xrange(a.shape[0]):
for j in xrange(b.shape[1]):
if (a[i,j] == test).all():
# Do something with indices i and j
代码显然很慢。我尝试使用numpy.where
,但它没有工作,因为它在四个索引中的每一个都寻找相等。
编辑:我还需要存储满足条件的索引(i,j)
答案 0 :(得分:1)
np.apply_over_axes(np.prod, a == test, [3,2]) == 1
为您提供一个大小为(60,80,1,1)
的数组,只要条件成立,它就是True
。线程启动程序找到的更短,更优选的版本是
(a == test).all(axis=(2,3))
两者都是等价的,但后者避免了布尔→整数→布尔转换。在该数组上使用np.where
来获取索引(i, j)
。