我怎么写NumPy argmode()?

时间:2016-01-17 03:38:26

标签: python math numpy

我知道argmax()会返回沿轴的最大值索引。

我也知道,在多次出现最大值的情况下,会返回与第一次出现相对应的索引。

当您想要找到最大值及其索引时,

argmax()可以正常工作。如何编写numpy.argmode()函数?

换句话说,如何在numpy数组中计算模式值并获取第一次出现的索引的函数是如何编写的?

只是所以每个人都知道没有numpy.argmode,但这个功能的功能是我所寻求的。

据我所知,该模式会多次出现。我们应该能够让它像argmax一样运行,如果我们有多次出现,它只返回第一次出现的值和索引。

我想要的一个例子是:

a = numpy.array([ 6, 3, 4, 1, 2, 2, 2])
numberIWant = numpy.argmode(a)
print(numberIWant)
# should print 4 (the index of the first occurrence of the mode)

我尝试使用:

stats.mode(a)[0][0]
numpy.argwhere(a==num)[0][0]

这确实有效,但我正在寻找一种更有效,更简洁的解决方案。 有任何想法吗?

2 个答案:

答案 0 :(得分:3)

如果你想留在NumPy中,你可以使用np.unique的一些额外回报来获得你想要的东西:

>>> _, idx, cnt = np.unique(a, return_index=True, return_counts=True)
>>> idx[np.argmax(cnt)]
4

修改

要提供有关正在发生的事情的一些背景信息... np.unique始终返回已排序的唯一值数组。可选的return_index提供另一个输出数组,其中包含每个唯一值的第一次出现的索引。并且可选的return_counts提供额外的输出,其中包含每个唯一值的出现次数。使用这些构建块,您需要做的就是在最高计数发生的位置返回索引数组的项。

答案 1 :(得分:2)

是什么让一个解决方案比另一个解决方案更“优雅”?短促?速度?聪明呢?大多数Pythonic? numpy的-ONIC?

对我而言,速度更紧凑,更紧凑。我可以通过将其包装在函数调用中来使解决方案更加紧凑。实际上,稳健性更为重要。

非笨拙的路线是使用collections中的便捷工具,如下所示:

In [342]: a = numpy.array([ 6, 3, 4, 1, 2, 2, 2])

In [343]: import collections

使用Counter快速获取模式(值):

In [344]: c=collections.Counter(a)
In [345]: c
Out[345]: Counter({2: 3, 1: 1, 3: 1, 4: 1, 6: 1})
In [347]: mode=c.most_common(1)[0][0]
In [348]: mode
Out[348]: 2

使用defaultdict收集所有值的位置:

In [349]: adict=collections.defaultdict(list)
In [350]: for i,v in enumerate(a):
    adict[v].append(i)
In [351]: adict[mode]
Out[351]: [4, 5, 6]

我可以在adict搜索最长的列表,但我怀疑Counter更快。

实际上,当我知道mode时,我需要的只是where - 正如您使用stats所示:

In [352]: np.where(a==mode)
Out[352]: (array([4, 5, 6], dtype=int32),)

在这个小阵列的时间测试中,Counter获胜。

In [358]: timeit stats.mode(a)[0][0]
1000 loops, best of 3: 337 µs per loop
In [359]: timeit collections.Counter(a).most_common(1)[0][0]
10000 loops, best of 3: 20 µs per loop

另一种可能的工具是bincount

In [367]: np.bincount(a)
Out[367]: array([0, 1, 3, 1, 1, 0, 1], dtype=int32)
In [368]: timeit np.argmax(np.bincount(a))
100000 loops, best of 3: 3.29 µs per loop

并使用where

In [373]: timeit np.where(a==np.argmax(np.bincount(a)))[0][0]
100000 loops, best of 3: 11.2 µs per loop

这很快,但我不确定它是否足够通用。