为什么numpy.linalg.norm在小尺寸数据被多次调用时会变慢?

时间:2018-04-16 21:38:42

标签: python performance numpy

import numpy as np
from datetime import datetime
import math

def norm(l):
    s = 0
    for i in l:
        s += i**2
    return math.sqrt(s)

def foo(a, b, f):
    l = range(a)
    s = datetime.now()
    for i in range(b):
        f(l)
    e = datetime.now()
    return e-s

foo(10**4, 10**5, norm)
foo(10**4, 10**5, np.linalg.norm)
foo(10**2, 10**7, norm)
foo(10**2, 10**7, np.linalg.norm)

我得到了以下输出:

0:00:43.156278
0:00:23.923239
0:00:44.184835
0:01:00.343875

似乎对于小型数据多次调用np.linalg.norm时,它的运行速度比我的norm函数慢。

原因是什么?

2 个答案:

答案 0 :(得分:3)

首先:datetime.now()不适合衡量效果,它包括壁挂时间,您可能只是在高优先级流程运行或Pythons时选择一个糟糕的时间(对于您的计算机) GC开始......,

Python中有专用的计时功能/模块:IPython / Jupyter中的内置timeit模块或%timeit以及其他几个外部模块(如perf,... )

让我们看看如果我在您的数据上使用这些内容会发生什么:

import numpy as np
import math

def norm(l):
    s = 0
    for i in l:
        s += i**2
    return math.sqrt(s)

r1 = range(10**4)
r2 = range(10**2)

%timeit norm(r1)
3.34 ms ± 150 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit np.linalg.norm(r1)
1.05 ms ± 3.92 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit norm(r2)
30.8 µs ± 1.53 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit np.linalg.norm(r2)
14.2 µs ± 313 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

对于短迭代而言,它的速度并不快,它仍然更快。但请注意,如果您已经拥有NumPy数组,那么NumPy函数的真正优势在于:

a1 = np.arange(10**4)
a2 = np.arange(10**2)

%timeit np.linalg.norm(a1)
18.7 µs ± 539 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit np.linalg.norm(a2)
4.03 µs ± 157 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

是的,现在它的速度要快得多。 18.7us vs. 1ms - 10000个元素快几百倍。这意味着您的示例中np.linalg.norm的大部分时间都用于将range转换为np.array

答案 1 :(得分:0)

你的方式正确

np.linalg.norm在小数组上的开销很高。在大型数组上,jit编译函数和np.linalg.norm都在内存瓶颈中运行,这在大多数时间进行简单乘法的函数上是预期的。

如果从另一个jitted函数调用jitted函数,它可能会被内联,这可以比numpy-norm函数带来更大的优势。

示例

import numba as nb
import numpy as np

@nb.njit(fastmath=True)
def norm(l):
    s = 0.
    for i in range(l.shape[0]):
        s += l[i]**2
    return np.sqrt(s)

<强>性能

r1 = np.array(np.arange(10**2),dtype=np.int32)
Numba:0.42µs
linalg:4.46µs

r1 = np.array(np.arange(10**2),dtype=np.int32)
Numba:8.9µs
linalg:13.4µs

r1 = np.array(np.arange(10**2),dtype=np.float64)
Numba:0.35µs
linalg:3.71µs

r2 = np.array(np.arange(10**4), dtype=np.float64)
Numba:1.4µs
linalg:5.6µs

衡量效果

  • 在测量之前调用jit-compiled函数一次(第一次调用时有静态编译开销)
  • 明确测量是否有效(因为小阵列保留在处理器缓存中,可能会有超出实际示例的RAM吞吐量的乐观结果,例如。example