Numba jitted len()比纯Python len()慢

时间:2019-12-18 09:37:23

标签: python numpy numba

我正在学习numba,遇到这种我不理解的“奇怪”行为。 我尝试使用以下代码(在iPython中,用于计时):

import numpy as np
import numba as nb

@nb.njit
def nb_len(seq):
    return len(seq)

def py_len(seq):
    return len(seq)

##
t = np.random.rand(1000)

%timeit nb_len(t)
%timeit py_len(t)

结果如下(实际上是第二次运行,因为编译了numba):

258 ns ± 1.37 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
137 ns ± 0.964 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

纯python版本的速度是numba版本的两倍。 我也尝试使用签名@nb.njit( nb.int32(nb.float64[:]) ),但结果仍然相同。

我在某个地方犯错了吗?

谢谢。

2 个答案:

答案 0 :(得分:4)

添加时间的不是len()部分。使用输入参数调用jit函数会增加开销,这就是您所看到的时差。

bulk.execute();

带有输入参数的结果

import numba as nb

def py_pass(i):
    return i

@nb.njit()
def nb_pass(i):
    return i

%timeit py_pass(1)
%timeit nb_pass(1)

有趣的是,如果您不需要将任何东西传递给jit函数,它会更快:

102 ns ± 0.371 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
165 ns ± 0.783 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

没有输入参数的结果

def py_pass():
    return 1

@nb.njit()
def nb_pass():
    return 1

%timeit py_pass()
%timeit nb_pass()

答案 1 :(得分:2)

正如other answer所说的,这不是因为使用len函数,而是因为对numba函数的调用实际上比对普通Python函数的调用慢。

是什么使jit-ted函数与众不同?

要了解为何调用numba jitted函数的速度较慢,必须理解numba jitted函数不再是一个函数。这是一个调度程序对象:

import numba as nb
@nb.njit
def nb_len(seq):
    return len(seq)
print(nb_len)  # CPUDispatcher(<function nb_len at 0x0000027EB1B4E798>)

CPUDispatcher实例(可能)表示基于修饰函数生成的多个编译函数。

这意味着当您调用CPUDispatcher实例时,需要执行多个步骤:

  • 获取参数的类型。
  • 如果没有适合这些参数类型的编译函数,请使用参数类型编译修饰后的函数。
  • 有时:将参数转换为相应的numba类型。
  • 调用已编译的函数。

与未装饰的功能相比,所有这些步骤都会增加开销。尤其是如果没有合适的已编译函数,并且调度程序需要编译该函数-或-输入类型需要转换(仅适用于Python类型,例如:列表,集合,字典),调用CPUDispatcher会慢很多-在numba 0.46中编写这些类型时,已经过时了,部分原因是因为,请参见"2.11.2. Deprecation of reflection for List and Set types"

以您的情况

在您的情况下,由于编译,第一次调用jitted函数的速度会大大降低。

任何后续调用只会稍微慢一些,因为numba必须获取参数类型,检查是否已存在编译函数,然后再调用该编译函数。有趣的是,额外的时间取决于参数的数量和该函数已编译的“重载”的数量。通常,额外的时间是微不足道的,因为该函数比调用len的功能要多得多。

编译时间

即使函数非常简单,第一次调用时的编译也会花费大量时间:

import numpy as np
import numba as nb

def first_call(seq):
    @nb.njit
    def nb_len(seq):
        return len(seq)
    return nb_len(seq)

@nb.njit
def _nb_len(seq):
    return len(seq)

def subsequent_calls(seq):
    return _nb_len(seq)

t = np.random.rand(1000)
_nb_len(np.ones(1, dtype=np.float64))

%timeit first_call(t)
# 29.8 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit subsequent_calls(t)
# 384 ns ± 6.02 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

转换时间

此外,如果numba需要转换参数,它将慢很多。这仅发生在numba无法直接处理的Python类型上,例如列表:

import numpy as np
import numba as nb

@nb.njit
def nb_len(seq):
    return len(seq)

arr = np.random.rand(10_000)
lst = arr.tolist()

nb_len(arr)
nb_len(lst)

%timeit nb_len(arr)
# 354 ns ± 24 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit nb_len(lst)
# 14.1 ms ± 950 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

摘要

  • 与普通的Python函数相比,Numba函数具有一些额外的开销。因此,请确保您做了“足够”的事情,使numba擅长优化,否则普通的Python函数将更快,更灵活且更易于调试。
  • numba函数中的函数调用与numba函数之外的函数调用确实可以不同。因此,len()中的nb_lenlen()中的py_len可以具有完全不同的运行时间。但是,在这种情况下,运行时间几乎相同。但是意识到这一点通常很好。
  • 根据参数类型,numba函数(在幕后)可能会非常慢,尤其是在将Python类型作为参数或返回类型的情况下!