我正在学习numba,遇到这种我不理解的“奇怪”行为。 我尝试使用以下代码(在iPython中,用于计时):
import numpy as np
import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
def py_len(seq):
return len(seq)
##
t = np.random.rand(1000)
%timeit nb_len(t)
%timeit py_len(t)
结果如下(实际上是第二次运行,因为编译了numba):
258 ns ± 1.37 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
137 ns ± 0.964 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
纯python版本的速度是numba版本的两倍。
我也尝试使用签名@nb.njit( nb.int32(nb.float64[:]) )
,但结果仍然相同。
我在某个地方犯错了吗?
谢谢。
答案 0 :(得分:4)
添加时间的不是len()部分。使用输入参数调用jit函数会增加开销,这就是您所看到的时差。
bulk.execute();
import numba as nb
def py_pass(i):
return i
@nb.njit()
def nb_pass(i):
return i
%timeit py_pass(1)
%timeit nb_pass(1)
有趣的是,如果您不需要将任何东西传递给jit函数,它会更快:
102 ns ± 0.371 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
165 ns ± 0.783 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
def py_pass():
return 1
@nb.njit()
def nb_pass():
return 1
%timeit py_pass()
%timeit nb_pass()
答案 1 :(得分:2)
正如other answer所说的,这不是因为使用len
函数,而是因为对numba函数的调用实际上比对普通Python函数的调用慢。
jit
-ted函数与众不同?要了解为何调用numba jitted函数的速度较慢,必须理解numba jitted函数不再是一个函数。这是一个调度程序对象:
import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
print(nb_len) # CPUDispatcher(<function nb_len at 0x0000027EB1B4E798>)
此CPUDispatcher
实例(可能)表示基于修饰函数生成的多个编译函数。
这意味着当您调用CPUDispatcher
实例时,需要执行多个步骤:
与未装饰的功能相比,所有这些步骤都会增加开销。尤其是如果没有合适的已编译函数,并且调度程序需要编译该函数-或-输入类型需要转换(仅适用于Python类型,例如:列表,集合,字典),调用CPUDispatcher
会慢很多-在numba 0.46中编写这些类型时,已经过时了,部分原因是因为,请参见"2.11.2. Deprecation of reflection for List and Set types"。
在您的情况下,由于编译,第一次调用jitted函数的速度会大大降低。
任何后续调用只会稍微慢一些,因为numba必须获取参数类型,检查是否已存在编译函数,然后再调用该编译函数。有趣的是,额外的时间取决于参数的数量和该函数已编译的“重载”的数量。通常,额外的时间是微不足道的,因为该函数比调用len
的功能要多得多。
即使函数非常简单,第一次调用时的编译也会花费大量时间:
import numpy as np
import numba as nb
def first_call(seq):
@nb.njit
def nb_len(seq):
return len(seq)
return nb_len(seq)
@nb.njit
def _nb_len(seq):
return len(seq)
def subsequent_calls(seq):
return _nb_len(seq)
t = np.random.rand(1000)
_nb_len(np.ones(1, dtype=np.float64))
%timeit first_call(t)
# 29.8 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit subsequent_calls(t)
# 384 ns ± 6.02 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
此外,如果numba需要转换参数,它将慢很多。这仅发生在numba无法直接处理的Python类型上,例如列表:
import numpy as np
import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
arr = np.random.rand(10_000)
lst = arr.tolist()
nb_len(arr)
nb_len(lst)
%timeit nb_len(arr)
# 354 ns ± 24 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit nb_len(lst)
# 14.1 ms ± 950 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
len()
中的nb_len
和len()
中的py_len
可以具有完全不同的运行时间。但是,在这种情况下,运行时间几乎相同。但是意识到这一点通常很好。