我正在编写代码来优化依赖于可变数量参数的数量。为了优化,我想一次跨多个轴应用索引选择函数,如numpy.argmax和numpy.argmin。下面是我现在正在使用的代码。是否有更内置或更有效的方法在任意数量的轴上执行此任务,这些轴可能顺序也可能不顺序?
def nd_arg_axes(func, array, start):
"""Applies an index selecting function over trailing axes from start."""
n_trail = len(array.shape[start:]) # Number of trailing axes to apply to.
indices = np.zeros((n_trail,)+array.shape[:start], dtype=np.intp)
for i in np.ndindex(array.shape[:start]):
indices[(Ellipsis,)+i] = np.unravel_index(func(array[i]),
array.shape[start:])
return tuple(indices)
# Test showing nd_arg_axes does indeed return the correct indices.
array = np.arange(27).reshape(3,3,3)
max_js = nd_arg_axes(np.argmax, array, 1)
(array[tuple(np.indices(array))+max_js] ==
np.squeeze(np.apply_over_axes(np.amax, array, axes=[1,2])))
答案 0 :(得分:2)
如果要选择尾随轴,可以将尾随轴重新整形为-1,并将func应用于轴= -1:
def f(func, array, start):
shape = array.shape
tmp = array.reshape(shape[:start] + (-1,))
indices = func(tmp, axis=-1)
return np.unravel_index(indices, shape[start:])