我尝试优化某些代码,其中一项耗时的操作如下:
import numpy as np
survivors = np.where(a > 0)[0]
pos = len(survivors)
a[:pos] = a[survivors]
b[:pos] = b[survivors]
c[:pos] = c[survivors]
在我的代码中a
是一个非常大的(超过100000个)NumPy浮点数组。他们中的许多人将是0.
有没有办法加快速度?
答案 0 :(得分:2)
据我所知,没有任何东西可以用纯粹的NumPy加快速度。但是,如果你有numba,你可以编写自己的版本"选择"使用jitted函数:
import numba as nb
@nb.njit
def selection(a, b, c):
insert_idx = 0
for idx, item in enumerate(a):
if item > 0:
a[insert_idx] = a[idx]
b[insert_idx] = b[idx]
c[insert_idx] = c[idx]
insert_idx += 1
在我的测试运行中,这比你的NumPy代码快了大约2倍。但是,如果您不使用conda
,那么numba可能会非常依赖。
>>> import numpy as np
>>> a = np.array([0., 1., 2., 0.])
>>> b = np.array([1., 2., 3., 4.])
>>> c = np.array([1., 2., 3., 4.])
>>> selection(a, b, c)
>>> a, b, c
(array([ 1., 2., 2., 0.]),
array([ 2., 3., 3., 4.]),
array([ 2., 3., 3., 4.]))
很难准确地计时,因为所有方法都可以就地工作,所以我实际上使用timeit.repeat
来测量时间number=1
(这可以避免由于in而导致时间损坏) - 解决方案的位置)我使用了结果列表中的min
,因为它被宣传为文档中最有用的量化指标:
注意
从结果向量计算平均值和标准偏差并报告这些是很诱人的。但是,这不是很有用。在典型情况下,最低值给出了机器运行给定代码段的速度的下限;结果向量中较高的值通常不是由Python的速度变化引起的,而是由于其他过程干扰您的计时准确性。因此结果的min()可能是您应该感兴趣的唯一数字。之后,您应该查看整个向量并应用常识而不是统计。
import timeit
min(timeit.repeat("""selection(a, b, c)""",
"""import numpy as np
from __main__ import selection
a = np.arange(1000000) % 3
b = a.copy()
c = a.copy()
""", repeat=100, number=1))
0.007700118746939211
import timeit
min(timeit.repeat("""survivors = np.where(a > 0)[0]
pos = len(survivors)
a[:pos] = a[survivors]
b[:pos] = b[survivors]
c[:pos] = c[survivors]""",
"""import numpy as np
a = np.arange(1000000) % 3
b = a.copy()
c = a.copy()
""", repeat=100, number=1))
0.028622144571883723
import timeit
min(timeit.repeat("""survivors = comb_array[:, 0].nonzero()[0]
comb_array[:len(survivors)] = comb_array[survivors]""",
"""import numpy as np
a = np.arange(1000000) % 3
b = a.copy()
c = a.copy()
comb_array = np.vstack([a,b,c]).T""", repeat=100, number=1))
0.058305527038669425
因此,Numba解决方案实际上可以将此速度提高3-4倍,而Alexander McFarlane的解决方案实际上比原始方法更慢(2倍)。然而,repeat
s的数量较少可能会对时间产生偏差。