我有一个3d numpy数组,我获得满足特定条件的索引,例如:
a = np.tile([[1,2],[3,4]],(2,2,2))
indices = np.where(a == 2)
对于这个索引,我需要应用一个偏移量,例如(0,0,1),并查看是否满足另一个条件。
这样的事情:
offset = [0, 0, 1]
indices_shift = indices + offset
count = 0
for i in indices_shift:
if a[i] == 3:
count += 1
在此示例中,偏移量为(0,0,1)时,索引如下所示:
indices = (array([0, 0, 0, 0, 1, 1, 1, 1], dtype=int64), array([0, 0, 2, 2, 0, 0, 2, 2], dtype=int64), array([1, 3, 1, 3, 1, 3, 1, 3], dtype=int64))
我认为添加偏移结果应该是这样的:
indices_shift = (array([0, 0, 0, 0, 1, 1, 1, 1], dtype=int64), array([0, 0, 2, 2, 0, 0, 2, 2], dtype=int64), array([2, 4, 2, 4, 2, 4, 2, 4], dtype=int64))
有没有简单的方法呢?
感谢。
答案 0 :(得分:1)
这是一种方法 -
idx = np.argwhere(a == 2)+[0,0,1]
valid_mask = (idx< a.shape).all(1)
valid_idx = idx[valid_mask]
count = np.count_nonzero(a[tuple(valid_idx.T)] == 3)
步骤:
获取针对2
的匹配的索引。在这里使用np.argwhere
来获取一个漂亮的2D
数组,每列代表一个轴。另一个好处是,这使得它可以处理具有通用维数的数组。然后,以广播方式添加offset
。这是idx
。
idx
中的索引中,几乎没有超出数组形状的无效索引。因此,获得一个有效的掩码valid_mask
,从而获得有效的索引valid_idx
。
最后用这些索引到输入数组中,与3
进行比较并计算匹配数。