将argmax / argmin分组为numpy中的分区索引

时间:2014-03-02 05:42:23

标签: python numpy

Numpy的ufunc有一个reduceat方法,可以在数组中的连续分区上运行它们。所以不要写:

import numpy as np
a = np.array([4, 0, 6, 8, 0, 9, 8, 5, 4, 9])
split_at = [4, 5]
maxima = [max(subarray for subarray in np.split(a, split_at)]

我可以写:

maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))

两者都会返回切片a[0:4]a[4:5]a[5:10]中的最大值,为[8, 0, 9]

我想要一个类似的函数来执行argmax,并指出我只想在每个分区中使用单个最大索引:[3, 4, 5]以及a }和split_at(尽管索引5和9都获得了最后一组中的最大值),如

所返回的那样
np.hstack([0, split_at]) + [np.argmax(subarray) for subarray in np.split(a, split_at)]

我将在下面发布一个可能的解决方案,但是希望看到一个没有在组上创建索引的矢量化。

2 个答案:

答案 0 :(得分:1)

此解决方案涉及在上面的示例中构建组([0, 0, 0, 0, 1, 2, 2, 2, 2, 2])的索引。

group_lengths = np.diff(np.hstack([0, split_at, len(a)]))
n_groups = len(group_lengths)
index = np.repeat(np.arange(n_groups), group_lengths)

然后我们可以使用:

maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a)
result = np.empty(len(group_lengths), dtype='i')
result[index[all_argmax[::-1]]] = all_argmax[::-1]

[3, 4, 5]中获取result[::-1]确保我们获得第一个而不是每个组中的最后一个argmax。

这依赖于这样一个事实:花哨赋值中的最后一个索引确定了分配的值,@ seberg says one shouldn't rely on(以及result = all_argmax[np.unique(index[all_argmax], return_index=True)[1]]可以实现更安全的替代,其中涉及len(maxima) ~ n_groups的排序1}}元素)。

答案 1 :(得分:0)

受此问题的启发,我已将{argmin / max功能添加到numpy_indexed包中。这是相应的测试的样子。请注意,密钥可以是任何顺序(以及npi支持的任何类型):

def test_argmin():
    keys   = [2, 0, 0, 1, 1, 2, 2, 2, 2, 2]
    values = [4, 5, 6, 8, 0, 9, 8, 5, 4, 9]
    unique, amin = group_by(keys).argmin(values)
    npt.assert_equal(unique, [0, 1, 2])
    npt.assert_equal(amin,   [1, 4, 0])