多维数组上的np.argmax,保持一些索引不变

时间:2015-06-21 13:39:26

标签: numpy multidimensional-array indexing argmax

我有一组二维情景,取决于两个整数索引,比如p1和p2,每个矩阵的形状相同。

然后我需要为每对(p1,p2)找到矩阵的最大值和这些最大值的索引。 一个微不足道的,尽管很慢的方法是做这样的事情

import numpy as np
import itertools
range1=range(1,10)
range2=range(1,20)

for p1,p2 in itertools.product(range1,range1):
    mat=np.random.rand(10,10)
    index=np.unravel_index(mat.argmax(), mat.shape)
    m=mat[index]
    print m, index

对于我的应用程序,不幸的是这太慢了,我想由于使用了double for循环。 因此,我试图将所有东西打包成一个四维数组(比如BigMatrix),其中前两个坐标是索引p1,p2,另外两个是矩阵的坐标。

np.amax命令

    >>res=np.amax(BigMatrix,axis=(2,3))
    >>res.shape
         (10,20)
    >>res[p1,p2]==np.amax(BigMatrix[p1,p2,:,:])
         True

按预期工作,因为它循环穿过2轴和3轴。我如何为np.argmax做同样的事情?请记住,速度很重要。

非常感谢您提前,

恩佐

1 个答案:

答案 0 :(得分:1)

这对我来说很有用Mat是大矩阵。

# flatten the 3 and 4 dimensions of Mat and obtain the 1d index for the maximum
# for each p1 and p2
index1d = np.argmax(Mat.reshape(Mat.shape[0],Mat.shape[1],-1),axis=2)

# compute the indices of the 3 and 4 dimensionality for all p1 and p2
index_x, index_y = np.unravel_index(index1d,Mat[0,0].shape)

# bring the indices into the right shape
index = np.array((index_x,index_y)).reshape(2,-1).transpose()

# get the maxima
max_val = np.amax(Mat,axis=(2,3)).reshape(-1)

# combine maxima and indices
sol = np.column_stack((max_val,index))

print sol