Python Numba / jit条件和递归(堆栈)使用

时间:2015-04-12 08:01:29

标签: python jit numba

所有

我正在使用numba JIT来加速我的Python代码,但即使numba&未安装LLVM。

我的第一个想法是这样做:

use_numba = True
try:
    from numba import jit, int32
except ImportError, e:
    use_numba = False

def run_it(parameters):
    # do something
    pass

# define wrapper call function with optimizer
@jit
def run_it_with_numba(parameters):
    return run_it(parameters)

# [...]
# main program 
t_start = timeit.default_timer()

# this is the code I don't like 
if use_numba:
    res = run_it_with_numba(parameters)
else:
    res = run_it(parameters)

t_stop = timeit.default_timer()
print "Numba: ", use_numba, " Time: ", t_stop - t_start

这不能像我预期的那样工作,因为编译似乎只适用于run_it_with_numba()函数 - 它基本上什么都不做 - 但不是从该函数调用的子程序。

当我在包含工作负载的函数上应用@jit时,结果才会变得更好。

是否有机会避免主程序中的包装函数和if子句?

有没有办法告诉Numba优化从我的输入函数调用的子程序?因为run_it()还包含一些函数调用,我希望@jit能够处理它。

铜, 麦酒

2 个答案:

答案 0 :(得分:5)

如果未安装Numba,您可以提供jit的无操作版本:

use_numba = True
try:
    from numba import jit, int32
except ImportError, e:
    use_numba = False
    from _shim import jit, int32

@jit
def run_it(parameters):
    # do something
    pass

# [...]
# main program 
t_start = timeit.default_timer()

res = run_it(eval(row[0]), workfeed, instrument)

t_stop = timeit.default_timer()
print "Numba: ", use_numba, " Time: ", t_stop - t_start

_shim.py只包含:

def jit(*args, **kwargs):
    def wrapper(f):
        return f
    if len(args) > 0 and (args[0] is marker or not callable(args[0])) \
        or len(kwargs) > 0:
        # @jit(int32(int32, int32)), @jit(signature="void(int32)")
        return wrapper
    elif len(args) == 0:
        # @jit()
        return wrapper
    else:
        # @jit
        return args[0]

def marker(*args, **kwargs): return marker

int32 = marker

答案 1 :(得分:1)

我认为你想以不同的方式做到这一点。而不是包装方法,只是选择别名。例如,使用虚拟方法来实现实际时间:

import numpy as np
import timeit 

use_numba = False
try:
    import numba as nb
except ImportError, e:
    use_numba = False

def _run_it(a, N):
    s = 0.0
    for k in xrange(N):
        s += k / np.sin(a)

    return s

# define wrapper call function with optimizer
if use_numba:
    print 'Using numba'
    run_it = nb.jit()(_run_it)
else:
    print 'Falling back to python'
    run_it = _run_it

if __name__ == '__main__':
    print timeit.repeat('run_it(50.0, 100000)', setup='from __main__ import run_it', repeat=3, number=100)

使用use_numba标记为True

运行此标记
$ python nbtest.py
Using numba
[0.18746304512023926, 0.15185213088989258, 0.1636970043182373]

False

$ python nbtest.py
Falling back to python
[9.707707166671753, 9.779848098754883, 9.770231008529663]

或在iPython笔记本中使用漂亮的%timeit魔法:

run_it_numba = nb.jit()(_run_it)

%timeit _run_it(50.0, 10000)
100 loops, best of 3: 9.51 ms per loop

%timeit run_it_numba(50.0, 10000)  
10000 loops, best of 3: 144 µs per loop

请注意,在对numba方法进行计时时,单次执行该方法的时间将考虑numba jit方法所需的时间。所有后续运行都会快得多。