在仅有少量元素的轴上加速ndarray.argmax

时间:2015-01-28 17:33:48

标签: python performance numpy

当它迭代的轴只有几个元素时,我试图加快ndarray.argmax()的性能(我也希望改善ndarray.max()的性能但是会满足于argmax()如果那是可能的话。数组的其他维度有很多元素。

以下是一个例子:

from numpy import *
r = random.rand(128,128,128,2)
o = empty(shape(r[...,0]), dtype=int)

## numpy .argmax()
r.argmax(axis=3, out=o)

## Using logic arguments
less(r[...,0], r[...,1], out=o)

在我的计算机上,第二个版本的速度大约是此示例的两倍。我有两个问题:1。有没有办法进一步提高速度? 2.这种方法是否可以扩展到具有3到5个元素的轴,同时仍保持性能优势?

1 个答案:

答案 0 :(得分:0)

那么,

o = r[...,0] < r[...,1]

比两者都快(在我的系统上)。

至于你的第二个问题,我不确定在你的第三个轴上比较超过2个值的情况下你想要什么输出。你能用一个例子来澄清吗? argmaxless执行不同的操作只是因为您的最后一个轴的维度为2,您在此处得到相同的结果。对于值的简单less / True比较,此解决方案将替换False,但argmax实际上会找到最大值的索引。