使用沿一个轴应用的argmin索引对多维数组建立索引

时间:2017-05-09 16:05:53

标签: python numpy

我有两个多维数组,让我们说x和y都有5个维度,我想找到x的值,y的最后一个分量是最小值。 要查找索引,我只使用I=argmin(y,axis=-1),这将返回一个4维索引数组。我该怎么做才能找到这些索引的x值?某种x[I]

2 个答案:

答案 0 :(得分:3)

方法#1:基本advanced-indexing扩展到5D个案。为了使事情更方便,我们可以使用np.ogrid的开放范围数组,然后执行advanced-indexing,就像这样 -

d0,d1,d2,d3,d4 = x.shape
s0,s1,s2,s3 = np.ogrid[:d0,:d1,:d2,:d3]
ymin = y[s0,s1,s2,s3,I]
xmin = x[s0,s1,s2,s3,I]

方法#2:我们可以通过将前两个步骤与np.ix_合并来缩短它,因此有一个通用函数来处理通用维数的ndarray -

indxs = np.ix_(*[np.arange(i) for i in x.shape[:-1]]) + (I,)
ymin = y[indxs]
xmin = x[indxs]

让我们使用一些样本随机值数组并通过直接计算min沿着最后一个轴y.min(axis=-1),即y.min(-1)并将其与索引值{{1}进行比较来验证来自建议的代码 -

ymin

答案 1 :(得分:1)

argmin与1或2d数组一起使用非常简单,但如果使用3或更多,则映射难以理解:

In [332]: y=np.arange(24)
In [333]: np.random.shuffle(y)
In [334]: y=y.reshape(2,3,4)
In [335]: y
Out[335]: 
array([[[19, 12,  9, 21],
        [ 8, 13, 20, 17],
        [22, 11,  5,  1]],

       [[ 7,  2, 23, 16],
        [ 0, 10,  6,  4],
        [14, 18, 15,  3]]])

In [338]: I = np.argmin(y, axis=-1)
In [339]: I
Out[339]: 
array([[2, 0, 3],
       [1, 0, 3]], dtype=int32)
In [340]: np.min(y, axis=-1)
Out[340]: 
array([[9, 8, 1],
       [2, 0, 3]])

结果是(2,3),每个平面/行的一个索引。

I[0,0]表示y[i,j,I[i,j]]行中i,j是最小值。

因此,我们需要一种方法来生成i,j配对

In [345]: i,j = np.ix_(np.arange(2), np.arange(3))
In [346]: i
Out[346]: 
array([[0],
       [1]])
In [347]: j
Out[347]: array([[0, 1, 2]])

In [349]: y[i,j,I[i,j]]
Out[349]: 
array([[9, 8, 1],
       [2, 0, 3]])

或者缩短为:

In [350]: y[i,j,I]
Out[350]: 
array([[9, 8, 1],
       [2, 0, 3]])

即使使用2d,方法也是一样的:

In [360]: z=y[:,:,1]
In [361]: z
Out[361]: 
array([[12, 13, 11],
       [ 2, 10, 18]])
In [362]: idx=np.argmin(z, axis=-1)
In [363]: idx
Out[363]: array([2, 0], dtype=int32)
In [364]: z[[0,1], idx]       # index the 1st dim with range
Out[364]: array([11,  2])

使用mgrid可以更容易地显示流程:

In [378]: i,j =np.mgrid[0:2,0:3]
In [379]: i
Out[379]: 
array([[0, 0, 0],
       [1, 1, 1]])
In [380]: j
Out[380]: 
array([[0, 1, 2],
       [0, 1, 2]])
In [381]: y[i, j, I]
Out[381]: 
array([[9, 8, 1],
       [2, 0, 3]])

此处ij是(2,3)数组,其形状匹配I。 3个数组一起选择y中的(2,3)元素数组。

ix_ogrid只生成等效的open数组。