使用np.argpartition索引多维数组中的值

时间:2017-03-16 11:23:08

标签: python numpy

我有一个像这样的数组:

>>> a = np.arange(60).reshape([3,4,5])
>>> a
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

我想沿其中一个维度检索前k个值。例如,我将选择k = 2并沿中间维度。

我尝试使用argpartition并且似​​乎做了正确的事情,但我在使用它的输出时无法从原始数组中检索值。以下是我如何使用argpartition:

>>> indices = np.argpartition(a, 2, axis=1)
>>> indices
array([[[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]]])

>>> indices[:,-2:,:]
array([[[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]]])

但我无法通过切片使用这些索引来获取值。

>>> a[:,indices[:,-2:,:],:].shape
(3, 3, 2, 5, 5)

我期待看到一个形状(3,2,5)的数组(因为我正在寻找沿中轴的前2位),我想这看起来像这样:

>>> magic_output
array([[[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

如何使用argpartition

中的索引访问值

2 个答案:

答案 0 :(得分:2)

np.argpartition获得最小的k索引。因此,要获得最高k索引,我们需要沿所需轴使用否定输入数组。然后,我们需要使用这些索引使用NumPy's advanced-indexing索引到该轴并具有所需的输出。

因此,实施将是 -

k = 2
m,n = a.shape[0], a.shape[2]
idx = np.argpartition(-a,k,axis=1)[:,k-1::-1]
out = a[np.arange(m)[:,None,None], idx, np.arange(n)]

示例运行 -

1)输入:

In [180]: a
Out[180]: 
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

2)提议的代码:

In [206]: k = 2
     ...: m,n = a.shape[0], a.shape[2]
     ...: idx = np.argpartition(-a,k,axis=1)[:,k-1::-1]
     ...: out = a[np.arange(m)[:,None,None], idx, np.arange(n)]
     ...: 

3)检查中间结果和输出:

In [207]: idx
Out[207]: 
array([[[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]]])

In [208]: out
Out[208]: 
array([[[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

答案 1 :(得分:1)

您不需要argpartition只需使用np.sort()沿第二轴对数组进行排序,然后从该轴中选择最后两项:

np.sort(a, 2)[:, -2:, ]

以下是关于数组的随机播放版本的示例:

In [15]: np.random.shuffle(a)

In [16]: a
Out[16]: 
array([[[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]],

       [[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]]])

In [17]: np.sort(a, 2)[:, -2:, ]
Out[17]: 
array([[[50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]],

       [[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]]])