如何有效切片numpy数组

时间:2017-03-23 03:02:09

标签: python arrays numpy numpy-broadcasting

我想以矩阵操作的方式实现以下代码,而不是使用for循环。

a = np.random.randint(0, 7, (4,3))
b = np.random.randint(0, 6, (4,3,2))
c = None
for idx in xrange(a.shape[0]):
     max_idx = np.argmax(a[idx])
     ex_b = b[idx, max_idx].reshape(1, -1)
     if c is None:
         c = ex_b
     else:
         c = np.concatenate((c, ex_b), axis=0)

基本上,我想首先在第二维中获得最大值的索引。然后我想根据这些索引在b中提取相应的第三维值。

例如:

a:
array([[5, 4, 1],
       [3, 1, 3],
       [4, 1, 2],
       [0, 0, 5]])
b:
array([[[1, 3], [1, 4], [5, 0]],
       [[2, 4], [2, 2], [1, 2]],
       [[2, 1], [1, 2], [4, 5]],
       [[4, 0], [5, 5], [0, 2]]])

然后np.argmax(a, axis=1)array([0, 0, 0, 2])  所以c[0] = b[0][0], c[1]=b[1]b[0], c[2]=b[2][0], c[3]=b[3][2]

我认为这个for循环会降低运行速度,是否有更优雅的方式以更快的方式实现这一目标?

1 个答案:

答案 0 :(得分:2)

您可以像这样使用花式索引:

b[np.arange(4), np.argmax(a, axis=-1)]
# array([[1, 3],
#        [2, 4],
#        [2, 1],
#        [0, 2]])