使用numpy.argsort的输出对numpy.ndarray的每一列进行排序

时间:2017-04-04 22:14:46

标签: python arrays sorting numpy

我想根据先前处理的参考数组对numpy 2D数组进行排序。 我的想法是存储我的引用数组的numpy.argsort输出并使用它来对其他数组进行排序:

In [13]: # my reference array
    ...: ref_arr = np.random.randint(10, 30, 12).reshape(3, 4)
Out[14]:
array([[12, 22, 12, 13],
       [28, 26, 21, 23],
       [24, 14, 16, 25]])

# desired output:
array([[12, 14, 12, 13],
       [24, 22, 16, 23],
       [28, 26, 21, 25]])

我尝试了什么:

In [15]: # store the sorting matrix
    ...: sm = np.argsort(ref_arr, axis=0)
Out[16]:
array([[0, 2, 0, 0],
       [2, 0, 2, 1],
       [1, 1, 1, 2]])

但不幸的是,最后一步只适用于一维数组:

In [17]: ref_arr[sm]
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-17-48b785178465> in <module>()
----> 1 ref_arr[sm]

IndexError: index 3 is out of bounds for axis 0 with size 3

我发现这个Github issue是针对这个问题而创建的,但不幸的是,它通过提到我尝试过的仅用于1D数组来解决。

In a comment to this issue提到了一个类似于我的问题的例子。该代码段无法解决我的问题,因为它按而非对数组进行排序。但它暗示了我必须朝哪个方向移动......

a[np.arange(np.shape(a)[0])[:,np.newaxis], np.argsort(a)]

不幸的是,我不明白这个例子足以让它适应我的用例。也许有人可以解释这种高级索引如何在这里工作?这可能使我能够自己解决这个问题但是我不介意一个交钥匙解决方案。 ;)

谢谢。

以防万一:我在OS X上使用Python 3.6.1和numpy 1.12.1。

2 个答案:

答案 0 :(得分:4)

基本上需要两个步骤:

1]使用axis=0 -

获取每个col的argsort索引
sidx = ref_arr.argsort(axis=0)

2]使用advanced-indexing使用sidx来选择行,即索引到第一个维度,并使用另一个范围数组索引到第二个维度,以便覆盖sidx所有列的索引 -

out = ref_arr[sidx, np.arange(sidx.shape[1])]

示例运行 -

In [185]: ref_arr
Out[185]: 
array([[12, 22, 12, 13],
       [28, 26, 21, 23],
       [24, 14, 16, 25]])

In [186]: sidx = ref_arr.argsort(axis=0)

In [187]: sidx
Out[187]: 
array([[0, 2, 0, 0],
       [2, 0, 2, 1],
       [1, 1, 1, 2]])

In [188]: ref_arr[sidx, np.arange(sidx.shape[1])]
Out[188]: 
array([[12, 14, 12, 13],
       [24, 22, 16, 23],
       [28, 26, 21, 25]])

答案 1 :(得分:1)

截至2018年5月,可以使用np.take_along_axis

np.take_along_axis(ref_arr, sm, axis=0)
Out[25]: 
array([[10, 16, 15, 10],
       [13, 23, 24, 12],
       [28, 26, 28, 28]])