如何获得多维数组的最小值索引的数组?

时间:2019-06-28 17:31:29

标签: python multidimensional-array indexing

我有一个形状为(2,2,3)的多维数组:

array([[[  0.64,   0.49,   2.56],
    [  7.84,  13.69,  21.16]],

   [[ 33.64,  44.89,  57.76],
    [ 77.44,  94.09, 112.36]]])

我想找到每一行的最小值索引。因此,在此示例中,有4个最小值:0.49、7.84、33.64和77.44。

要获得这些最小值的索引,我认为这会起作用:

idx_arr = np.unravel_index(np.argmin(my_array,axis=2),my_array.shape)

这将产生以下索引数组:

(array([[0, 0],
    [0, 0]]), array([[0, 0],
    [0, 0]]), array([[1, 0],
    [0, 0]]))

但是,如人们所见,最小值没有正确计算:

my_array[idx_arr]
array([[0.49, 0.64],
   [0.64, 0.64]])

我在那里想念什么?

1 个答案:

答案 0 :(得分:1)

argmin实际上正在正确计算值。但是您误解了np.unravel_index的期望。

来自文档:

  

将平面索引或平面索引数组转换为的元组   坐标数组。

要查看在此处提供期望的输出将接受哪种输入,我们需要集中注意要点:它将以非平面术语将平面数组转换为特定位置的正确坐标数组。本质上,期望的是期望点的坐标,就像将输入数组展平了。

import numpy as np
inp = np.array([[[  0.64,   0.49,   2.56],
    [  7.84,  13.69,  21.16]],

   [[ 33.64,  44.89,  57.76],
    [ 77.44,  94.09, 112.36]]])

idx = inp.argmin(axis=-1)
#Output:
array([[1, 0],
       [0, 0]], dtype=int64)

请注意,您无法直接发送此idx,因为它不能代表inp数组的扁平版本的正确坐标。

这看起来更像是以下内容:

flat_idx = np.arange(0, idx.size*inp.shape[-1], inp.shape[-1]) + idx.flatten()
#Output:
array([1, 3, 6, 9], dtype=int64)

我们可以看到unravel_index很高兴接受它。

temp = np.unravel_index(flat_idx, inp.shape)
#Output:
(array([0, 0, 1, 1], dtype=int64),
 array([0, 1, 0, 1], dtype=int64),
 array([1, 0, 0, 0], dtype=int64))

inp[temp]

输出:

array([ 0.49,  7.84, 33.64, 77.44])

此外,看看输出元组,我们可以注意到,自己重新创建相同的对象也不太困难。请注意,最后一个数组对应于idx的展平形式,而前两个数组本质上允许通过inp的前两个轴进行索引。


为此,我们可以以一种非常漂亮的方式实际使用unravel_index函数,如下所示:

real_idx = (*np.unravel_index(np.arange(idx.size), idx.shape), idx.flatten())
inp[real_idx]
#Output:
array([ 0.49,  7.84, 33.64, 77.44])