Python numpy性能 - 选择非常大的数组

时间:2017-08-10 10:35:42

标签: python arrays performance numpy indexing

我尝试优化某些代码,其中一项耗时的操作如下:

import numpy as np
survivors = np.where(a > 0)[0]
pos = len(survivors)
a[:pos] = a[survivors]
b[:pos] = b[survivors]
c[:pos] = c[survivors]

在我的代码中a是一个非常大的(超过100000个)NumPy浮点数组。他们中的许多人将是0.

有没有办法加快速度?

1 个答案:

答案 0 :(得分:2)

据我所知,没有任何东西可以用纯粹的NumPy加快速度。但是,如果你有numba,你可以编写自己的版本"选择"使用jitted函数:

import numba as nb

@nb.njit
def selection(a, b, c):
    insert_idx = 0
    for idx, item in enumerate(a):
        if item > 0:
            a[insert_idx] = a[idx]
            b[insert_idx] = b[idx]
            c[insert_idx] = c[idx]
            insert_idx += 1

在我的测试运行中,这比你的NumPy代码快了大约2倍。但是,如果您不使用conda,那么numba可能会非常依赖。

实施例

>>> import numpy as np
>>> a = np.array([0., 1., 2., 0.])
>>> b = np.array([1., 2., 3., 4.])
>>> c = np.array([1., 2., 3., 4.])
>>> selection(a, b, c)
>>> a, b, c
(array([ 1.,  2.,  2.,  0.]),
 array([ 2.,  3.,  3.,  4.]),
 array([ 2.,  3.,  3.,  4.]))

定时:

很难准确地计时,因为所有方法都可以就地工作,所以我实际上使用timeit.repeat来测量时间number=1(这可以避免由于in而导致时间损坏) - 解决方案的位置)我使用了结果列表中的min,因为它被宣传为文档中最有用的量化指标:

  

注意

     

从结果向量计算平均值和标准偏差并报告这些是很诱人的。但是,这不是很有用。在典型情况下,最低值给出了机器运行给定代码段的速度的下限;结果向量中较高的值通常不是由Python的速度变化引起的,而是由于其他过程干扰您的计时准确性。因此结果的min()可能是您应该感兴趣的唯一数字。之后,您应该查看整个向量并应用常识而不是统计。

Numba解决方案
import timeit

min(timeit.repeat("""selection(a, b, c)""",
              """import numpy as np
from __main__ import selection

a = np.arange(1000000) % 3
b = a.copy()
c = a.copy()
""", repeat=100, number=1))
  

0.007700118746939211

原始解决方案

import timeit

min(timeit.repeat("""survivors = np.where(a > 0)[0]
pos = len(survivors)
a[:pos] = a[survivors]
b[:pos] = b[survivors]
c[:pos] = c[survivors]""",
              """import numpy as np
a = np.arange(1000000) % 3
b = a.copy()
c = a.copy()
""", repeat=100, number=1))
  

0.02862214457188​​3723

Alexander McFarlane的解决方案(现已删除)

import timeit

min(timeit.repeat("""survivors = comb_array[:, 0].nonzero()[0]
comb_array[:len(survivors)] = comb_array[survivors]""",
              """import numpy as np
a = np.arange(1000000) % 3
b = a.copy()
c = a.copy()

comb_array = np.vstack([a,b,c]).T""", repeat=100, number=1))
  

0.058305527038669425

因此,Numba解决方案实际上可以将此速度提高3-4倍,而Alexander McFarlane的解决方案实际上比原始方法更慢(2倍)。然而,repeat s的数量较少可能会对时间产生偏差。