如何将numpy.argpartition的输出应用于二维数组?

时间:2014-10-12 05:38:12

标签: python arrays performance numpy indexing

我有一个较大的2d numpy数组,我想提取每行的最低10个元素及其索引。由于我的数组很大,我宁愿不对整个数组进行排序。

我听说过argpartition()函数,我可以用它获得最低10个元素的索引:

top10indexes = np.argpartition(myBigArray,10)[:,:10]

请注意argpartition()默认分区轴-1,这就是我想要的。此处的结果与myBigArray具有相同的形状,其中包含各个行的索引,以便前10个索引指向10个最低值。

我现在如何提取与这些索引相对应的myBigArray元素?

myBigArray[top10indexes]myBigArray[:,top10indexes]这样明显的花哨索引做了很多不同的事情。我也可以使用列表推导,例如:

array([row[idxs] for row,idxs in zip(myBigArray,top10indexes)])

但这会导致性能损失迭代numpy行并将结果转换回数组。

nb:我可以使用np.partition()来获取值,它们甚至可能对应于索引(或者可能不是......),但如果我不想做两次分区可以避免它。

1 个答案:

答案 0 :(得分:6)

您可以通过执行以下操作来避免使用拼合副本以及提取所有值的需要:

num = 10
top = np.argpartition(myBigArray, num, axis=1)[:, :num]
myBigArray[np.arange(myBigArray.shape[0])[:, None], top]

对于NumPy> = 1.9.0,这将非常有效并且与np.take()相当。