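Hi all,
I'm using numba's JIT to speed up my Python code, but the code should also keep working when numba and LLVM are not installed.
My first idea was to do it like this: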
import timeit

use_numba = True
try:
    from numba import jit, int32
except ImportError:
    use_numba = False

def run_it(parameters):
    # do something
    pass

# define the wrapper call function with the optimizer
# (only when numba is available, otherwise @jit is undefined)
if use_numba:
    @jit
    def run_it_with_numba(parameters):
        return run_it(parameters)

# [...]
# main program
t_start = timeit.default_timer()
# this is the code I don't like
if use_numba:
    res = run_it_with_numba(parameters)
else:
    res = run_it(parameters)
t_stop = timeit.default_timer()
print("Numba:", use_numba, " Time:", t_stop - t_start)
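This doesn't work as I expected: the compilation seems to apply only to run_it_with_numba(), which does basically nothing itself, and not to the subroutines it calls.
The results only improve when I apply @jit to the function that actually contains the workload.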
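Is there a way to avoid the wrapper function and the if clause in the main program?
Is there a way to tell Numba to also optimize the subroutines called from my entry function? run_it() contains several function calls of its own, and I'd like @jit to handle those too.
See you, Ale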
Answer 0 (score: 5)
If Numba is not installed, you can provide a no-op version of jit:
import timeit

use_numba = True
try:
    from numba import jit, int32
except ImportError:
    use_numba = False
    # fall back to no-op versions of jit and int32 (see _shim.py)
    from _shim import jit, int32

@jit
def run_it(parameters):
    # do something
    pass

# [...]
# main program
t_start = timeit.default_timer()
res = run_it(parameters)
t_stop = timeit.default_timer()
print("Numba:", use_numba, " Time:", t_stop - t_start)
_shim.py contains only:
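def jit(*args, **kwargs):
    # decorator that returns the function unchanged
    def wrapper(f):
        return f
    if len(args) > 0 and (args[0] is marker or not callable(args[0])) \
            or len(kwargs) > 0:
        # @jit(int32(int32, int32)), @jit(signature="void(int32)")
        return wrapper
    elif len(args) == 0:
        # @jit()
        return wrapper
    else:
        # @jit
        return args[0]

# stand-in for numba types: calling it (e.g. int32(int32)) returns marker again,
# so jit() can recognize a signature argument
def marker(*args, **kwargs):
    return marker

int32 = marker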
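To check that the shim covers the decorator forms listed in its comments, here is a small self-test. It copies the shim from _shim.py into one file so it runs on its own; the decorated functions f to j are hypothetical examples. Each form should hand back the original, undecorated function.

```python
def marker(*args, **kwargs):
    return marker

int32 = marker


def jit(*args, **kwargs):
    def wrapper(f):
        return f
    if len(args) > 0 and (args[0] is marker or not callable(args[0])) \
            or len(kwargs) > 0:
        return wrapper          # @jit(signature) or @jit(option=...)
    elif len(args) == 0:
        return wrapper          # @jit()
    else:
        return args[0]          # bare @jit


@jit
def f(x):
    return x + 1

@jit()
def g(x):
    return x + 2

@jit(int32(int32))              # int32(int32) evaluates to marker
def h(x):
    return x + 3

@jit("int32(int32)")            # string signature: not callable
def i(x):
    return x + 4

@jit(nopython=True)             # keyword options only
def j(x):
    return x + 5
```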
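Answer 1 (score: 1)
I think you want to do this differently: instead of wrapping the function, just pick an alias. For example, using a dummy function to get some actual timings: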
import timeit

import numpy as np

use_numba = True  # set to False to force the pure-Python fallback
try:
    import numba as nb
except ImportError:
    use_numba = False

def _run_it(a, N):
    s = 0.0
    for k in range(N):
        s += k / np.sin(a)
    return s

# choose the implementation once, by alias, instead of wrapping it
if use_numba:
    print('Using numba')
    run_it = nb.jit()(_run_it)
else:
    print('Falling back to python')
    run_it = _run_it

if __name__ == '__main__':
    print(timeit.repeat('run_it(50.0, 100000)', setup='from __main__ import run_it',
                        repeat=3, number=100))
With the use_numba flag set to True:
$ python nbtest.py
Using numba
[0.18746304512023926, 0.15185213088989258, 0.1636970043182373]
And with it set to False:
$ python nbtest.py
Falling back to python
[9.707707166671753, 9.779848098754883, 9.770231008529663]
Or, in an IPython notebook, with the handy %timeit magic:
run_it_numba = nb.jit()(_run_it)
%timeit _run_it(50.0, 10000)
100 loops, best of 3: 9.51 ms per loop
%timeit run_it_numba(50.0, 10000)
10000 loops, best of 3: 144 µs per loop
Note that when timing the numba function, the time of its first execution includes the time numba needs to JIT-compile it. All subsequent runs are much faster.
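One way to keep the one-off compilation cost out of the benchmark is to time the first call separately and then take the best of several later runs, which is roughly what %timeit reports. A minimal sketch, assuming the same dummy workload as above; it swaps np.sin for math.sin so the example has no NumPy dependency, and it simply times the plain Python function when numba is missing.

```python
import math
import timeit

# Use numba if it is installed; otherwise time the plain Python function.
try:
    import numba as nb
    HAVE_NUMBA = True
except ImportError:
    HAVE_NUMBA = False


def _run_it(a, n):
    s = 0.0
    for k in range(n):
        s += k / math.sin(a)
    return s


run_it = nb.jit()(_run_it) if HAVE_NUMBA else _run_it

# First call: with numba this includes the one-off JIT compilation.
t0 = timeit.default_timer()
first_result = run_it(50.0, 10000)
first_call = timeit.default_timer() - t0

# Later calls reuse the compiled code; report the best of 3 runs, like %timeit.
per_call = min(timeit.repeat(lambda: run_it(50.0, 10000), repeat=3, number=10)) / 10

print("numba:", HAVE_NUMBA)
print("first call (incl. compile): %.6f s" % first_call)
print("steady-state per call:      %.6f s" % per_call)
```

With numba installed, first_call should be dominated by compilation and be far larger than per_call; without it, the two numbers should be of the same order.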