对从np.where获得的索引应用偏移量

时间:2017-07-18 10:16:57

标签: python numpy

我有一个3d numpy数组,我获得满足特定条件的索引,例如:

a = np.tile([[1,2],[3,4]],(2,2,2))
indices = np.where(a == 2)

对于这个索引,我需要应用一个偏移量,例如(0,0,1),并查看是否满足另一个条件。

这样的事情:

offset = [0, 0, 1]
indices_shift = indices + offset

count = 0
for i in indices_shift:
    if a[i] == 3:
       count += 1

在此示例中,偏移量为(0,0,1)时,索引如下所示:

indices = (array([0, 0, 0, 0, 1, 1, 1, 1], dtype=int64), array([0, 0, 2, 2, 0, 0, 2, 2], dtype=int64), array([1, 3, 1, 3, 1, 3, 1, 3], dtype=int64))

我认为添加偏移结果应该是这样的:

indices_shift = (array([0, 0, 0, 0, 1, 1, 1, 1], dtype=int64), array([0, 0, 2, 2, 0, 0, 2, 2], dtype=int64), array([2, 4, 2, 4, 2, 4, 2, 4], dtype=int64))

有没有简单的方法呢?

感谢。

1 个答案:

答案 0 :(得分:1)

这是一种方法 -

idx = np.argwhere(a == 2)+[0,0,1]
valid_mask = (idx< a.shape).all(1)
valid_idx = idx[valid_mask]
count = np.count_nonzero(a[tuple(valid_idx.T)] == 3)

步骤:

  • 获取针对2的匹配的索引。在这里使用np.argwhere来获取一个漂亮的2D数组,每列代表一个轴。另一个好处是,这使得它可以处理具有通用维数的数组。然后,以广播方式添加offset。这是idx

  • idx中的索引中,几乎没有超出数组形状的无效索引。因此,获得一个有效的掩码valid_mask,从而获得有效的索引valid_idx

  • 最后用这些索引到输入数组中,与3进行比较并计算匹配数。