numba fibonacci函数较慢并且不会返回相同的结果

时间:2017-09-19 12:47:19

标签: python python-2.7 performance numba

此代码返回缓慢且输出不同:

from numba import jit
from timeit import default_timer as timer
def fibonacci(n):
    a, b = 1, 1
    for i in range(n):
        a, b = a+b, a
    return a
fibonacci_jit = jit(fibonacci)

让我们开始测试

start = timer()
print fibonacci(100)
duration = timer() - start

让我们开始测试

startnext = timer()
print fibonacci_jit(100)
durationnext = timer() - startnext


print(duration, durationnext)

结果:

C:\Python27>python numba_test_003.py
927372692193078999176
1445263496
(0.00038264393810854576, 0.17378674127528523)

#next

C:\Python27>python numba_test_003.py
927372692193078999176
1445263496
(0.0004830358514597401, 0.19266426987655644)

2 个答案:

答案 0 :(得分:1)

由于你只运行一次Numba jitted函数,你会看到jit编译时间和运行时间的总和。下次运行numba函数时,您只会看到运行时,因为numba会为每个唯一的输入参数类型缓存已编译的代码,所以它会更快:

startnext = timer()
print fibonacci_jit(100)
durationnext = timer() - startnext
print(duration, durationnext)

#5035488507601418376
#(0.0003879070281982422, 0.14705300331115723)

startnext = timer()
print fibonacci_jit(100)
durationnext = timer() - startnext

print(duration, durationnext)

#5035488507601418376
#(0.0003879070281982422, 0.0002810955047607422)

答案的不同之处在于Python本机对象int具有无限精度,而numba使用的是具有有限容量且可能溢出的类似C的本机int。如果你运行较小输入的函数,你应该看到它同意,直到你溢出numba int。

答案 1 :(得分:1)

减速的原因是编译时间。第一次调用未签名的numba-jitted函数时,它将检查类型并为这些参数编译函数。后续运行会更快,因为它已经编译过:

for _ in range(5):
    start = timer()
    fibonacci_jit(100)
    print(timer() - start)

0.18958417814776496      # first run - includes compilation
6.1441049545862825e-06
3.3513299761978033e-06
3.3513299761978033e-06
3.3513299761978033e-06

但是,因为numba使用C类型,所以整数会溢出。您可以轻松检查类型:

fibonacci_jit.inspect_types()

fibonacci (int64,)
--------------------------------------------------------------------------------
# File: <ipython-input-19-a73271f1a552>
# --- LINE 3 --- 
# label 0
#   del $const0.1
#   del $0.4
#   del $0.2
#   del $0.3

def fibonacci(n):

    # --- LINE 4 --- 
    #   n = arg(0, name=n)  :: int64
    #   $const0.1 = const(tuple, (1, 1))  :: (int64 x 2)
    #   $0.4 = exhaust_iter(value=$const0.1, count=2)  :: (int64 x 2)
    #   $0.2 = static_getitem(value=$0.4, index=0, index_var=None)  :: int64
    #   $0.3 = static_getitem(value=$0.4, index=1, index_var=None)  :: int64
    #   a = $0.2  :: int64
    #   b = $0.3  :: int64
    #   jump 8
    # label 8

    a, b = 1, 1

    # --- LINE 5 --- 
    #   jump 10
    # label 10
    #   $10.1 = global(range: <class 'range'>)  :: Function(<class 'range'>)
    #   $10.3 = call $10.1(n, func=$10.1, args=[Var(n, <ipython-input-19-a73271f1a552> (4))], kws=(), vararg=None)  :: (int64,) -> range_state_int64
    #   del n
    #   del $10.1
    #   $10.4 = getiter(value=$10.3)  :: range_iter_int64
    #   del $10.3
    #   $phi18.1 = $10.4  :: range_iter_int64
    #   del $10.4
    #   jump 18
    # label 18
    #   $18.2 = iternext(value=$phi18.1)  :: pair<int64, bool>
    #   $18.3 = pair_first(value=$18.2)  :: int64
    #   $18.4 = pair_second(value=$18.2)  :: bool
    #   del $18.2
    #   $phi20.1 = $18.3  :: int64
    #   $phi38.1 = $18.3  :: int64
    #   del $phi38.1
    #   del $18.3
    #   $phi38.2 = $phi18.1  :: range_iter_int64
    #   del $phi38.2
    #   branch $18.4, 20, 38
    # label 20
    #   del $18.4
    #   i = $phi20.1  :: int64
    #   del i
    #   del $phi20.1
    #   del $20.4
    #   del $a20.5

    for i in range(n):

        # --- LINE 6 --- 
        #   $20.4 = a + b  :: int64
        #   $a20.5 = a  :: int64
        #   a = $20.4  :: int64
        #   b = $a20.5  :: int64
        #   jump 18
        # label 38
        #   del b
        #   del $phi20.1
        #   del $phi18.1
        #   del $18.4
        #   jump 40
        # label 40
        #   del a

        a, b = a+b, a

    # --- LINE 7 --- 
    #   $40.2 = cast(value=a)  :: int64
    #   return $40.2

    return a


================================================================================

至少在我的计算机上它使用int64,因此最大可能值为9223372036854775807。你不能用numba解决这个问题。如果你需要任意精度整数,你必须坚持使用Python。