我有一个平面阵列b
:
a = numpy.array([0, 1, 1, 2, 3, 1, 2])
一个数组c
的索引标记每个" chunk的开头":
b = numpy.array([0, 4])
我知道我可以在每个" chunk"中找到最大值。使用减少:
m = numpy.maximum.reduceat(a,b)
>>> array([2, 3], dtype=int32)
但是......有没有办法在一个块<edit>
(如</edit>
)中找到最大numpy.argmax
的索引,使用矢量化操作(没有列表,循环)?
答案 0 :(得分:2)
从this post
借用这个想法。
涉及的步骤:
通过限制偏移量抵消组中的所有元素。对它们进行全局排序,从而限制每个组保持其位置,但对每个组中的元素进行排序。
在排序数组中,我们会查找最后一个元素,即最大组。他们的指数是在抵消群长后的argmax。
因此,矢量化实现将是 -
def numpy_argmax_reduceat(a, b):
n = a.max()+1 # limit-offset
grp_count = np.append(b[1:] - b[:-1], a.size - b[-1])
shift = n*np.repeat(np.arange(grp_count.size), grp_count)
sortidx = (a+shift).argsort()
grp_shifted_argmax = np.append(b[1:],a.size)-1
return sortidx[grp_shifted_argmax] - b
作为一个小调整,可能更快,我们也可以用shift
创建cumsum
,因此可以改变之前的方法,就像这样 -
def numpy_argmax_reduceat_v2(a, b):
n = a.max()+1 # limit-offset
id_arr = np.zeros(a.size,dtype=int)
id_arr[b[1:]] = 1
shift = n*id_arr.cumsum()
sortidx = (a+shift).argsort()
grp_shifted_argmax = np.append(b[1:],a.size)-1
return sortidx[grp_shifted_argmax] - b