我有一个numpy数组,除了第一个,我想把argmax放在所有轴上。我(我认为)有一个解决方案,但我想知道是否有更好的方法来实现它。
import numpy as np
def argmax(array):
## Argmax along all axes except the first (ie axis 0)
last_axis = len(array.shape) - 1
if last_axis == 0:
return tuple(range(array.size))
if last_axis == 1:
return (range(array.shape[0]), list(np.argmax(array, axis=1)))
index_array = np.argmax(array, axis=last_axis)
smaller_array = np.amax(array, axis=last_axis)
assert index_array.shape == smaller_array.shape
argmax_smaller_array = argmax(smaller_array)
return argmax_smaller_array + (list(index_array[argmax_smaller_array]), )
一些例子:
a = np.arange(12).reshape((6, 2))
a[5, 0] = 22
argmax(a)
a[argmax(a)]
b = np.arange(18).reshape((3, 3, 2))
b[0, 0, 0] = 55
b[argmax(b)]
np.all(b[argmax(b)] == np.array([np.max(b[0]), np.max(b[1]), np.max(b[2])])) # True
我刚开始玩numpy,我想知道是否有更简单的方法来做到这一点。我在改写已存在的东西吗?
答案 0 :(得分:1)
您的方法似乎没问题,但您计算了许多您不需要的中间结果。你可以这样做:
import numpy as np
def argmax(array):
shape = array.shape
array = array.reshape((shape[0], -1))
ravelmax = np.argmax(array, axis=1)
return (np.arange(shape[0]),) + np.unravel_index(ravelmax, shape[1:])