Numpy ndarray按行排序,然后排除索引

时间:2017-11-29 08:18:05

标签: python sorting numpy multidimensional-array

我有一个numpy.ndarray如下:

from numpy import array
a = array( [[1,1,0.4], [1,1,0.3],[0.4,0.3,1]] )

array([[ 1. ,  1. ,  0.4],
       [ 1. ,  1. ,  0.3],
       [ 0.4,  0.3,  1. ]])

以下是专栏:

dataidx = array( [1,2,3] )

我想按行对上面的数组进行排序,然后指定相关的dataidx:

indices = np.argsort(-a, axis=1)
result = np.hstack((dataidx[:, None], dataidx[indices]))
print(result)
[[1 1 2 3]
 [2 1 2 3]
 [3 3 1 2]]

对于每一行,如何根据第一列排除dataidx,如下所示?

[[1 2 3]
 [2 1 3]
 [3 1 2]]

1 个答案:

答案 0 :(得分:1)

这是一种方式 -

In [56]: m = result.shape[0]

In [57]: mask = np.c_[[True]*m,result[:,1:] != result[:,0,None]]

In [58]: result[mask].reshape(m,-1)
Out[58]: 
array([[1, 2, 3],
       [2, 1, 3],
       [3, 1, 2]])

这是另一个 -

In [105]: rm_idx = (result[:,1:] == result[:,0,None]).argmax(1)+1

In [106]: mask = np.ones(result.shape, dtype=bool)

In [107]: mask[np.arange(len(rm_idx)), rm_idx] = 0

In [108]: result[mask].reshape(result.shape[0],-1)
Out[108]: 
array([[1, 2, 3],
       [2, 1, 3],
       [3, 1, 2]])