如果最后一个轴索引小于另一个2D数组

时间:2018-03-01 13:14:28

标签: python arrays numpy

我有一个形状为(m,n,p)的3D数组a和一个形状为(m,n)的二维数组idx。我希望a中最后一个轴索引小于idx中相应元素的所有元素都设置为0。

以下代码有效。我的问题是:有更有效的方法吗?

a = np.array([[[1, 2, 3],
               [4, 5, 6]],

              [[7, 8, 9],
               [10, 11, 12]],

              [[21, 22, 23],
               [25, 26, 27]]])
idx = np.array([[2, 1],
                [0, 1],
                [1, 1]])
for (i, j), val in np.ndenumerate(idx):
    a[i, j, :val] = 0

结果是

array([[[ 0,  0,  3],
        [ 0,  5,  6]],

       [[ 7,  8,  9],
        [ 0, 11, 12]],

       [[ 0, 22, 23],
        [ 0, 26, 27]]])

1 个答案:

答案 0 :(得分:3)

使用broadcasting创建3D蒙版,然后使用boolean-indexing -

分配零
mask = idx[...,None] > np.arange(a.shape[2])
a[mask] = 0

或者,我们也可以使用NumPy内置进行外部更大比较来获得该掩码 -

mask = np.greater.outer(idx, np.arange(a.shape[2]))

运行给定样本 -

In [34]: mask = idx[...,None] > np.arange(a.shape[2])

In [35]: a[mask] = 0

In [36]: a
Out[36]: 
array([[[ 0,  0,  3],
        [ 0,  5,  6]],

       [[ 7,  8,  9],
        [ 0, 11, 12]],

       [[ 0, 22, 23],
        [ 0, 26, 27]]])