numpy:索引问题

时间:2015-01-09 14:23:53

标签: python arrays numpy

我遇到了一些问题,我对一个numpy数组进行排序并获取排序数组的索引,但将索引应用于原始数组并不能达到我的预期。所以,这是我正在做的测试案例:

 import numpy as np

 # Two 3x3 matrices
 x = np.random.rand(2, 3, 3)
 # Perform some decomposition (Never mind the matrices are not hermitian...) 
 evals, evecs = np.linalg.eigh(x)

 # evals has shape (2, 3), evecs has shape (2, 3, 3)
 indices = evals.argsort(axis=1)[..., ::-1] # Do descending sort

 # Now I want to apply the index to evals.
 evals = evals[:, indices]

我得到一个(2,3,3)数组,而不是返回一个(2,3)数组,而这些数组将被复制。类似的东西:

array([[[ 1.15628047,  0.16853886, -0.28607138],
        [ 1.15628047,  0.16853886, -0.28607138]],

       [[ 2.4311532 , -0.00754817, -0.24086572],
        [ 2.4311532 , -0.00754817, -0.24086572]]])

我不确定为什么会这样。非常感谢任何帮助。

1 个答案:

答案 0 :(得分:1)

这应该有效:

import numpy as np
idx0 = np.arange(evals.shape[0])[:,np.newaxis]
idx1 = evals.argsort(1)[...,::-1]
evals[idx0,idx1]

这会按顺序逐个对每一行进行排序。

编辑: 在这种情况下,您需要(idx0,idx1)来进一步处理特征向量evecs。如果不是这样的话,那就很容易做到

evals.sort()
evals = evals[:,::-1]