请考虑以下简单测试:
import numpy as np
from timeit import timeit
a = np.random.randint(0,2,1000000,bool)
让我们找到第一个True
的索引
timeit(lambda:a.argmax(), number=1000)
# 0.000451055821031332
这是相当快的,因为numpy
短路。
它也适用于连续切片,
timeit(lambda:a[1:-1].argmax(), number=1000)
# 0.0006490410305559635
但是,似乎不连续的不是。我主要对查找最后一个True
感兴趣:
timeit(lambda:a[::-1].argmax(), number=1000)
# 0.3737605109345168
更新:我认为观察到的减速是由于没有短路造成的,这是不准确的(感谢@Victor Ruiz)。确实,在 所有
False
数组的最坏情况
b=np.zeros_like(a)
timeit(lambda:b.argmax(), number=1000)
# 0.04321779008023441
我们仍然比不连续的要快一个数量级 案件。我准备接受维克多的解释,即真正的罪魁祸首 是正在制作的副本(使用
.copy()
强制进行副本的时间是 暗示)。之后,是否真的不再重要 短路是否发生。
但是其他步长!= 1也会产生类似的行为。
timeit(lambda:a[::2].argmax(), number=1000)
# 0.19192566303536296
问题:在最后两个示例中,numpy
为什么不短路 UPDATE 而无需复制?
而且,更重要的是:是否有一种解决方法,即可以强制numpy
伪造 UPDATE 而不复制副本,连续的数组?
答案 0 :(得分:10)
问题与使用跨步时数组的内存对齐有关。
a[1:-1]
,a[::-1]
被认为在内存中对齐,但是a[::2]
不要:
a = np.random.randint(0,2,1000000,bool)
print(a[1:-1].flags.c_contiguous) # True
print(a[::-1].flags.c_contiguous) # False
print(a[::2].flags.c_contiguous) # False
这说明了np.argmax
在a[::2]
上运行缓慢的原因(来自ndarrays的文档):
NumPy中的几种算法可用于任意跨距的数组。但是,某些算法需要单段数组。将不规则步距的数组传递给此类算法时,将自动创建一个副本。
np.argmax(a[::2])
正在复制数组。因此,如果您执行timeit(lambda: np.argmax(a[::2]), number=5000)
,则将计时5000个数组a
的副本
执行此操作,并比较这两个计时调用的结果:
print(timeit(lambda: np.argmax(a[::2]), number=5000))
b = a[::2].copy()
print(timeit(lambda: np.argmax(b), number=5000))
编辑:
深入研究numpy C语言的源代码,我发现了argmax
函数PyArray_ArgMax的下划线实现,该函数在某个时候调用PyArray_ContiguousFromAny以确保给定的输入数组在内存中对齐(C风格)
然后,如果数组的dtype为bool,则它将委托给BOOL_argmax函数。 看一下它的代码,似乎总是被应用了。
np.argmax
复制,请确保输入数组在内存中是连续的答案 1 :(得分:2)
我对解决这个问题感兴趣。因此,我提供了下一个解决方案,该解决方案可以避免由于a[::-1]
内部ndarray复制而导致的“ np.argmax
”问题:
我创建了一个小型库,该库实现了np.argmax
的包装函数argmax
,但是当输入参数是步幅值设置为-1的一维布尔数组时,它的性能提高了:
https://github.com/Vykstorm/numpy-bool-argmax-ext
在这种情况下,它使用低级C例程查找从数组{{的末尾到开头}的最大值(k
)的项的索引True
。 1}}。
然后您可以使用a
argmax(a[::-1])
低级方法不执行任何内部ndarray副本,因为它与数组len(a)-k-1
一起运行,该数组已经C连续并且在内存中对齐。它还适用于短路
编辑:
在处理与-1不同的步幅值(使用1D布尔数组)时,我也扩展了库以提高性能a
,从而获得了不错的效果:argmax
,a[::2]
,等等。
尝试一下。