我想了解使用Numba在for循环中加速纯numpy
代码时的收益来源。是否有任何分析工具可以让您查看jitted
函数?
演示代码(如下所示)只是使用非常基本的矩阵乘法来为计算机提供工作。观察到的收益来自:
loop
,numpy
在编译过程中截获的jit
个函数,或jit
减少开销,因为numpy通过包装函数将执行外包到低级库,例如LINPACK
%matplotlib inline
import numpy as np
from numba import jit
import pandas as pd
#Dimensions of Matrices
i = 100
j = 100
def pure_python(N,i,j):
for n in range(N):
a = np.random.rand(i,j)
b = np.random.rand(i,j)
c = np.dot(a,b)
@jit(nopython=True)
def jit_python(N,i,j):
for n in range(N):
a = np.random.rand(i,j)
b = np.random.rand(i,j)
c = np.dot(a,b)
time_python = []
time_jit = []
N = [1,10,100,500,1000,2000]
for n in N:
time = %timeit -oq pure_python(n,i,j)
time_python.append(time.average)
time = %timeit -oq jit_python(n,i,j)
time_jit.append(time.average)
df = pd.DataFrame({'pure_python' : time_python, 'jit_python' : time_jit}, index=N)
df.index.name = 'Iterations'
df[["pure_python", "jit_python"]].plot()
生成以下图表。
答案 0 :(得分:7)
TL:DR随机和循环加速,但矩阵乘法除了小矩阵大小外没有。在小矩阵/循环大小,似乎有可能与python开销有关的显着加速。在大N时,矩阵乘法开始占主导地位且jit不太有用
函数定义,为简单起见使用方阵。
from IPython.display import display
import numpy as np
from numba import jit
import pandas as pd
#Dimensions of Matrices
N = 1000
def py_rand(i, j):
a = np.random.rand(i, j)
jit_rand = jit(nopython=True)(py_rand)
def py_matmul(a, b):
c = np.dot(a, b)
jit_matmul = jit(nopython=True)(py_matmul)
def py_loop(N, val):
count = 0
for i in range(N):
count += val
jit_loop = jit(nopython=True)(py_loop)
def pure_python(N,i,j):
for n in range(N):
a = np.random.rand(i,j)
b = np.random.rand(i,j)
c = np.dot(a,a)
jit_func = jit(nopython=True)(pure_python)
定时:
df = pd.DataFrame(columns=['Func', 'jit', 'N', 'Time'])
def meantime(f, *args, **kwargs):
t = %timeit -oq -n5 f(*args, **kwargs)
return t.average
for N in [10, 100, 1000, 2000]:
a = np.random.randn(N, N)
b = np.random.randn(N, N)
df = df.append({'Func': 'jit_rand', 'N': N, 'Time': meantime(jit_rand, N, N)}, ignore_index=True)
df = df.append({'Func': 'py_rand', 'N': N, 'Time': meantime(py_rand, N, N)}, ignore_index=True)
df = df.append({'Func': 'jit_matmul', 'N': N, 'Time': meantime(jit_matmul, a, b)}, ignore_index=True)
df = df.append({'Func': 'py_matmul', 'N': N, 'Time': meantime(py_matmul, a, b)}, ignore_index=True)
df = df.append({'Func': 'jit_loop', 'N': N, 'Time': meantime(jit_loop, N, 2.0)}, ignore_index=True)
df = df.append({'Func': 'py_loop', 'N': N, 'Time': meantime(py_loop, N, 2.0)}, ignore_index=True)
df = df.append({'Func': 'jit_func', 'N': N, 'Time': meantime(jit_func, 5, N, N)}, ignore_index=True)
df = df.append({'Func': 'py_func', 'N': N, 'Time': meantime(pure_python, 5, N, N)}, ignore_index=True)
df['jit'] = df['Func'].str.contains('jit')
df['Func'] = df['Func'].apply(lambda s: s.split('_')[1])
df.set_index('Func')
display(df)
结果:
Func jit N Time
0 rand True 10 1.030686e-06
1 rand False 10 1.115149e-05
2 matmul True 10 2.250371e-06
3 matmul False 10 2.199343e-06
4 loop True 10 2.706000e-07
5 loop False 10 7.274286e-07
6 func True 10 1.217046e-05
7 func False 10 2.495837e-05
8 rand True 100 5.199217e-05
9 rand False 100 8.149794e-05
10 matmul True 100 7.848071e-05
11 matmul False 100 2.130794e-05
12 loop True 100 2.728571e-07
13 loop False 100 3.003743e-06
14 func True 100 6.739634e-04
15 func False 100 1.146594e-03
16 rand True 1000 5.644258e-03
17 rand False 1000 8.012790e-03
18 matmul True 1000 1.476098e-02
19 matmul False 1000 1.613211e-02
20 loop True 1000 2.846572e-07
21 loop False 1000 3.539849e-05
22 func True 1000 1.256926e-01
23 func False 1000 1.581177e-01
24 rand True 2000 2.061612e-02
25 rand False 2000 3.204709e-02
26 matmul True 2000 9.866484e-02
27 matmul False 2000 1.007234e-01
28 loop True 2000 3.011143e-07
29 loop False 2000 7.477454e-05
30 func True 2000 1.033560e+00
31 func False 2000 1.199969e+00
看起来numba正在优化循环,所以我不打算在比较中包括它
情节:
def jit_speedup(d):
py_time = d[d['jit'] == False]['Time'].mean()
jit_time = d[d['jit'] == True]['Time'].mean()
return py_time / jit_time
import seaborn as sns
result = df.groupby(['Func', 'N']).apply(jit_speedup).reset_index().rename(columns={0: 'Jit Speedup'})
result = result[result['Func'] != 'loop']
sns.factorplot(data=result, x='N', y='Jit Speedup', hue='Func')
因此,对于5次重复的循环,jit可以非常稳定地加速,直到矩阵乘法变得足够昂贵,相比之下,其他开销变得微不足道。