如何从python中的nd数组中提取特定行

时间:2018-07-23 12:33:03

标签: python arrays performance numpy indexing

我有一个数组b,其形状为E x B x 3。我还有另一个数组a,用于指定要使用b的3个元素。

以下代码有效(在本示例中为E=2, B=4):

import numpy as np

a = [1, 1, 0, 0]
b = np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 3.0]],
            [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0],  [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]]])
# n_pred = np.transpose(n_pred, axes=[1, 0, 2])
c = []
for i, idx in enumerate(a):
    c.append(b[idx, i])
c = np.array(c)
print(c)

我的问题是,有没有更有效的方法? (也许使用一些内置的numpy函数?

1 个答案:

答案 0 :(得分:1)

您可以按前两个维度编制索引:

c = b[a, range(len(a))]

print(c)

array([[ 2.,  0.,  0.],
       [ 0.,  2.,  0.],
       [ 0.,  0.,  1.],
       [ 0.,  0.,  3.]])