为什么每100k迭代打印一次会破坏numba性能?

时间:2018-12-06 16:19:05

标签: python printing jit numba

为什么要编写此代码,每100k迭代一次print(即仅打印40行!)需要50秒才能运行:

import numpy as np
from numba import jit

@jit
def doit():
    A = np.random.random(4*1000*1000)
    n = 300
    Q = np.zeros(len(A)-n)
    for i in range(len(Q)):
        Q[i] = np.sum(A[i:i+n] <= A[i+n])
        if i % 100000 == 0:  # print the progress once every 100k iterations
            print("%i %.2f %% already done. " % (i, i * 100.0 / len(A)))

doit()

在没有print的情况下,仅需2.4秒

import numpy as np
from numba import jit
@jit
def doit():
    A = np.random.random(4*1000*1000)
    n = 300
    Q = np.zeros(len(A)-n)
    for i in range(len(Q)):
        Q[i] = np.sum(A[i:i+n] <= A[i+n])
doit()

print真的可以消除numba的好处吗?

1 个答案:

答案 0 :(得分:3)

如果尝试使用@njit@jit(nopython=True)进行编译,则会从异常中看到它正在对象模式下进行编译。使用以下打印语句,该版本在我的计算机上运行大约1秒钟:

import numpy as np
from numba import jit

@jit(nopython=True)
def doit():
    A = np.random.random(4*1000*1000)
    n = 300
    Q = np.zeros(len(A)-n)
    for i in range(len(Q)):
        Q[i] = np.sum(A[i:i+n] <= A[i+n])
        if i % 100000 == 0:  # print the progress once every 100k iterations
            print(i , "(",  i * 100.0 / len(A), '% already done)')

通常,如果您发现numba函数的性能不佳,那是因为您是在python对象模式下进行编译,因此除非您确实想在python对象模式下使用它,否则始终将nopython=True放置是一个好习惯因为如果遇到某种语法,编译器就无法将其编译为机器代码,这将归结为这种情况。 Numba确实做了一些循环提升,但是就性能而言,这很难推理。

请参阅:

http://numba.pydata.org/numba-doc/latest/user/5minguide.html#what-is-nopython-mode