如何从包含最多3个值的2d numpy数组中获取列的索引

时间:2018-11-11 17:21:46

标签: python python-3.x numpy-ndarray

我有一个数组:

a = np.array([[22,11,44,33,66],
              [22,11,2,1,66],
              [1,11,44,22,4],
              [22,11,88,99,66]])

作为输出,我想要一个包含最大3个值的索引的数组作为2d数组。例如,上面的数组输出为:

array([[4,2,3],
       [4,0,1],
       [2,3,1],
       [3,2,4]])

1 个答案:

答案 0 :(得分:1)

要获取数组的前k个元素,请partition。由于分区通常为您提供k最低的元素,因此请使用反向索引:

k = 3
top = np.argpartition(a, -k, axis=1)[:, -k:]

如果您需要按降序对索引进行排序,请在结果中使用np.argsort

rows = np.arange(a.shape[0])[:, None]
s = np.argsort(a[rows, top], axis=1)[:, ::-1]
top = top[rows, s]
在使用rowstop进行华丽索引时,必须确保

s正确选择所有索引。每行的索引必须颠倒([:, ::-1])才能获得升序。