Numpy argmax - 随意打破领带

时间:2017-02-06 15:38:16

标签: python numpy

numpy.argmax函数中,多个max元素之间的中断是为了返回第一个元素。 是否有随机化打破平局的功能,以便所有最大数字都有相同的被选中机会?

以下是直接来自numpy.argmax文档的示例。

>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0, 5, 2, 3, 4, 5])
>>> np.argmax(b) # Only the first occurrence is returned.
1

我正在寻找方法,以便以相同的概率返回列表中的第1和第5个元素。

谢谢!

5 个答案:

答案 0 :(得分:22)

使用np.random.choice -

np.random.choice(np.flatnonzero(b == b.max()))

让我们验证一个包含三个最大候选者的数组 -

In [298]: b
Out[298]: array([0, 5, 2, 5, 4, 5])

In [299]: c=[np.random.choice(np.flatnonzero(b == b.max())) for i in range(100000)]

In [300]: np.bincount(c)
Out[300]: array([    0, 33180,     0, 33611,     0, 33209])

答案 1 :(得分:6)

对于多维数组,choice将不起作用。

另一种选择是

def randargmax(b,**kw):
  """ a random tie-breaking argmax"""
  return np.argmax(np.random.random(b.shape) * (b==b.max()), **kw)

如果由于某种原因生成随机浮点数比其他方法慢,random.random可以用其他方法替换。

答案 2 :(得分:4)

由于答案可能并不明显,因此它是这样工作的:

  • b == b.max()将返回一个布尔数组,对于最大项目,值为true;对于其他项目,值为false。
  • flatnonzero()将执行以下操作:忽略假值(非零部分),然后返回真值的索引。换句话说,您将获得一个数组,其中的项索引与最大值匹配
  • 最后,您从中随机选择这些索引

答案 3 :(得分:2)

最简单的方法是

np.random.choice(np.where(b == b.max())[0])

答案 4 :(得分:0)

除了@Manux的答案,

b.max()更改为np.amax(b,**kw, keepdims=True)将使您可以沿轴进行操作。

def randargmax(b,**kw):
    """ a random tie-breaking argmax"""
    return np.argmax(np.random.random(b.shape) * (b==b.max()), **kw)

randargmax(b,axis=None)