为什么numpy不会在非连续阵列上短路?

时间:2019-08-04 11:33:31

标签: python numpy short-circuiting

请考虑以下简单测试:

import numpy as np
from timeit import timeit

a = np.random.randint(0,2,1000000,bool)

让我们找到第一个True的索引

timeit(lambda:a.argmax(), number=1000)
# 0.000451055821031332

这是相当快的,因为numpy短路。

它也适用于连续切片,

timeit(lambda:a[1:-1].argmax(), number=1000)
# 0.0006490410305559635

但是,似乎不连续的不是。我主要对查找最后一个True感兴趣:

timeit(lambda:a[::-1].argmax(), number=1000)
# 0.3737605109345168
  

更新:我认为观察到的减速是由于没有短路造成的,这是不准确的(感谢@Victor Ruiz)。确实,在   所有False数组的最坏情况

b=np.zeros_like(a)
timeit(lambda:b.argmax(), number=1000)
# 0.04321779008023441
  

我们仍然比不连续的要快一个数量级   案件。我准备接受维克多的解释,即真正的罪魁祸首   是正在制作的副本(使用.copy()强制进行副本的时间是   暗示)。之后,是否真的不再重要   短路是否发生。

但是其他步长!= 1也会产生类似的行为。

timeit(lambda:a[::2].argmax(), number=1000)
# 0.19192566303536296

问题:在最后两个示例中,numpy为什么不短路 UPDATE 而无需复制

而且,更重要的是:是否有一种解决方法,即可以强制numpy伪造 UPDATE 而不复制副本,连续的数组?

2 个答案:

答案 0 :(得分:10)

问题与使用跨步时数组的内存对齐有关。 a[1:-1]a[::-1]被认为在内存中对齐,但是a[::2] 不要:

a = np.random.randint(0,2,1000000,bool)

print(a[1:-1].flags.c_contiguous) # True
print(a[::-1].flags.c_contiguous) # False
print(a[::2].flags.c_contiguous) # False

这说明了np.argmaxa[::2]上运行缓慢的原因(来自ndarrays的文档):

  

NumPy中的几种算法可用于任意跨距的数组。但是,某些算法需要单段数组。将不规则步距的数组传递给此类算法时,将自动创建一个副本。

np.argmax(a[::2])正在复制数组。因此,如果您执行timeit(lambda: np.argmax(a[::2]), number=5000),则将计时5000个数组a的副本

执行此操作,并比较这两个计时调用的结果:

print(timeit(lambda: np.argmax(a[::2]), number=5000))

b = a[::2].copy()
print(timeit(lambda: np.argmax(b), number=5000))

编辑: 深入研究numpy C语言的源代码,我发现了argmax函数PyArray_ArgMax的下划线实现,该函数在某个时候调用PyArray_ContiguousFromAny以确保给定的输入数组在内存中对齐(C风格)

然后,如果数组的dtype为bool,则它将委托给BOOL_argmax函数。 看一下它的代码,似乎总是被应用了。

摘要

  • 为了避免被np.argmax复制,请确保输入数组在内存中是连续的
  • 当数据类型为布尔值时,总是会发生短路。

答案 1 :(得分:2)

我对解决这个问题感兴趣。因此,我提供了下一个解决方案,该解决方案可以避免由于a[::-1]内部ndarray复制而导致的“ np.argmax”问题:

我创建了一个小型库,该库实现了np.argmax的包装函数argmax,但是当输入参数是步幅值设置为-1的一维布尔数组时,它的性能提高了:

https://github.com/Vykstorm/numpy-bool-argmax-ext

在这种情况下,它使用低级C例程查找从数组{{的末尾到开头}的最大值(k)的项的索引True。 1}}。
然后您可以使用a

计算argmax(a[::-1])

低级方法不执行任何内部ndarray副本,因为它与数组len(a)-k-1一起运行,该数组已经C连续并且在内存中对齐。它还适用于短路


编辑: 在处理与-1不同的步幅值(使用1D布尔数组)时,我也扩展了库以提高性能a,从而获得了不错的效果:argmaxa[::2],等等。

尝试一下。

相关问题