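I have some code that I'm trying to speed up with numba. I've done some reading on the topic, but I haven't been able to get it 100% sorted out.
Here is the code: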
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
import seaborn as sns
from numba import jit, vectorize, float64
sns.set(context='talk', style='ticks', font_scale=1.2, rc={'figure.figsize': (6.5, 5.5), 'xtick.direction': 'in', 'ytick.direction': 'in'})
#%% constraints
x_min = 0 # death below this
x_max = 20 # maximum weight
t_max = 100 # maximum time
foraging_efficiencies = np.linspace(0, 1, 10) # potential foraging efficiencies
R = 10.0 # Resource level
#%% make the body size and time categories
body_sizes = np.arange(x_min, x_max+1)
time_steps = np.arange(t_max)
#%% parameter functions
@jit
def metabolic_fmr(x, u, temp): # metabolic cost function
    fmr = 0.125*(2**(0.2*temp))*(1 + 0.5*u) + x*0.1
    return fmr
def intake_dist(u): # intake stochastic function (returns a vector)
    g = st.binom.pmf(np.arange(R+1), R, u)
    return g
@jit
def mass_gain(x, u, temp): # mass gain function (returns a vector)
    x_prime = x - metabolic_fmr(x, u, temp) + np.arange(R+1)
    x_prime = np.minimum(x_prime, x_max)
    x_prime = np.maximum(x_prime, 0)
    return x_prime
@jit
def prob_attack(P): # probability of an attack
    p_a = 0.02*P
    return p_a
@jit
def prob_see(u): # probability of not seeing an attack
    p_s = 1-(1-u)**0.3
    return p_s
@jit
def prob_lethal(x): # probability of lethality given a successful attack
    p_l = 0.5*np.exp(-0.05*x)
    return p_l
@jit
def prob_mort(P, u, x):
    p_m = prob_attack(P)*prob_see(u)*prob_lethal(x)
    return np.minimum(p_m, 1)
#%% terminal fitness function
@jit
def terminal_fitness(x):
    t_f = 15.0*x/(x+5.0)
    return t_f
#%% linear interpolation function
@jit
def linear_interpolation(x, F, t):
    floor = x.astype(int)
    delta_c = x-floor
    ceiling = floor + 1
    ceiling[ceiling>x_max] = x_max
    floor[floor<x_min] = x_min
    interpolated_F = (1-delta_c)*F[floor,t] + (delta_c)*F[ceiling,t]
    return interpolated_F
#%% solver
@jit
def solver_jit(P, temp):
    F = np.zeros((len(body_sizes), len(time_steps))) # Expected fitness
    F[:,-1] = terminal_fitness(body_sizes) # expected terminal fitness for every body size
    V = np.zeros((len(foraging_efficiencies), len(body_sizes), len(time_steps))) # Fitness for each foraging effort
    D = np.zeros((len(body_sizes), len(time_steps))) # Decision
    for t in range(t_max-1)[::-1]:
        for x in range(x_min+1, x_max+1): # iterate over every body size except dead
            for i in range(len(foraging_efficiencies)): # iterate over every possible foraging efficiency
                u = foraging_efficiencies[i]
                g_u = intake_dist(u) # calculate the distribution of intakes
                xp = mass_gain(x, u, temp) # calculate the mass gain
                p_m = prob_mort(P, u, x) # probability of mortality
                V[i,x,t] = (1 - p_m)*(linear_interpolation(xp, F, t+1)*g_u).sum() # Fitness calculation
            vmax = V[:,x,t].max()
            idx = np.argwhere(V[:,x,t]==vmax).min()
            D[x,t] = foraging_efficiencies[idx]
            F[x,t] = vmax
    return D, F
def solver_norm(P, temp):
    F = np.zeros((len(body_sizes), len(time_steps))) # Expected fitness
    F[:,-1] = terminal_fitness(body_sizes) # expected terminal fitness for every body size
    V = np.zeros((len(foraging_efficiencies), len(body_sizes), len(time_steps))) # Fitness for each foraging effort
    D = np.zeros((len(body_sizes), len(time_steps))) # Decision
    for t in range(t_max-1)[::-1]:
        for x in range(x_min+1, x_max+1): # iterate over every body size except dead
            for i in range(len(foraging_efficiencies)): # iterate over every possible foraging efficiency
                u = foraging_efficiencies[i]
                g_u = intake_dist(u) # calculate the distribution of intakes
                xp = mass_gain(x, u, temp) # calculate the mass gain
                p_m = prob_mort(P, u, x) # probability of mortality
                V[i,x,t] = (1 - p_m)*(linear_interpolation(xp, F, t+1)*g_u).sum() # Fitness calculation
            vmax = V[:,x,t].max()
            idx = np.argwhere(V[:,x,t]==vmax).min()
            D[x,t] = foraging_efficiencies[idx]
            F[x,t] = vmax
    return D, F
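The individual jitted functions tend to be much faster than their un-jitted versions. For example, prob_mort runs about 600% faster once it goes through jit. However, the solver itself is no faster: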
In [3]: %timeit -n 10 solver_jit(200, 25)
10 loops, best of 3: 3.94 s per loop
In [4]: %timeit -n 10 solver_norm(200, 25)
10 loops, best of 3: 4.09 s per loop
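I know that some functions can't be jitted, so I replaced the st.binom.pmf function with a custom jit function, which actually slowed each loop down to about 17 s, roughly 5x slower. Presumably that is because the scipy function is already heavily optimized.
So I suspect the slowness is either in the linear_interpolation function or somewhere in the solver code outside the jitted functions (because at one point I un-jitted all the functions, ran solver_norm, and got the same timing). Any ideas on where the slow part is and how to speed it up?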
Update
Here is the binomial code I used to try to speed things up with jit:
@jit
def factorial(n):
    if n==0:
        return 1
    else:
        return n*factorial(n-1)
@vectorize([float64(float64,float64,float64)])
def binom(k, n, p):
    binom_coef = factorial(n)/(factorial(k)*factorial(n-k))
    pmf = binom_coef*p**k*(1-p)**(n-k)
    return pmf
@jit
def intake_dist(u): # intake stochastic function (returns a vector)
    g = binom(np.arange(R+1), R, u)
    return g
Update 2: I tried running my binomial code in nopython mode and found out I was doing it wrong because it was recursive. I fixed it by changing the code to:
from numba import jit, vectorize, float64, int64
@jit(int64(int64), nopython=True)
def factorial(nn):
    res = 1
    for ii in range(2, nn + 1):
        res *= ii
    return res
@vectorize([float64(float64,float64,float64)], nopython=True)
def binom(k, n, p):
    binom_coef = factorial(n)/(factorial(k)*factorial(n-k))
    pmf = binom_coef*p**k*(1-p)**(n-k)
    return pmf
The solver now runs in:
In [34]: %timeit solver_jit(200, 25)
1 loop, best of 3: 921 ms per loop
That is about 3.5x faster. However, solver_jit() and solver_norm() still run at the same speed, which means some code outside the jit functions is slowing it down.
Answer 0 (score: 1)
I was able to make a few changes to your code so that the jit version compiles completely in nopython mode. On my laptop, this results in:
%timeit solver_jit(200, 25)
1 loop, best of 3: 50.9 ms per loop
%timeit solver_norm(200, 25)
1 loop, best of 3: 192 ms per loop
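For reference, I'm using Numba 0.27.0. I'll admit that Numba's compilation errors can still make it hard to figure out what is going on, but since I've been playing with it for a while, I've built up an intuition for what needs to be fixed. The full code is below, but here is a list of the changes I made:
- In linear_interpolation, changed x.astype(int) to x.astype(np.int64) so that it can compile in nopython mode.
- Used np.sum as a function rather than as a method of the array.
- np.argwhere is not supported, so I wrote a custom loop instead.
There are probably some further optimizations that could be made, but this gives an initial speedup.
Full code: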
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
import seaborn as sns
from numba import jit, vectorize, float64, njit
sns.set(context='talk', style='ticks', font_scale=1.2, rc={'figure.figsize': (6.5, 5.5), 'xtick.direction': 'in', 'ytick.direction': 'in'})
#%% constraints
x_min = 0 # death below this
x_max = 20 # maximum weight
t_max = 100 # maximum time
foraging_efficiencies = np.linspace(0, 1, 10) # potential foraging efficiencies
R = 10.0 # Resource level
#%% make the body size and time categories
body_sizes = np.arange(x_min, x_max+1)
time_steps = np.arange(t_max)
#%% parameter functions
@njit
def metabolic_fmr(x, u, temp): # metabolic cost function
    fmr = 0.125*(2**(0.2*temp))*(1 + 0.5*u) + x*0.1
    return fmr
@njit()
def factorial(nn):
    res = 1
    for ii in range(2, nn + 1):
        res *= ii
    return res
@vectorize([float64(float64,float64,float64)], nopython=True)
def binom(k, n, p):
    # factorial() loops over range(), which needs integer bounds
    ki, ni = int(k), int(n)
    binom_coef = factorial(ni)/(factorial(ki)*factorial(ni-ki))
    pmf = binom_coef*p**k*(1-p)**(n-k)
    return pmf
@njit
def intake_dist(u): # intake stochastic function (returns a vector)
    g = binom(np.arange(R+1), R, u)
    return g
@njit
def mass_gain(x, u, temp): # mass gain function (returns a vector)
    x_prime = x - metabolic_fmr(x, u, temp) + np.arange(R+1)
    x_prime = np.minimum(x_prime, x_max)
    x_prime = np.maximum(x_prime, 0)
    return x_prime
@njit
def prob_attack(P): # probability of an attack
    p_a = 0.02*P
    return p_a
@njit
def prob_see(u): # probability of not seeing an attack
    p_s = 1-(1-u)**0.3
    return p_s
@njit
def prob_lethal(x): # probability of lethality given a successful attack
    p_l = 0.5*np.exp(-0.05*x)
    return p_l
@njit
def prob_mort(P, u, x):
    p_m = prob_attack(P)*prob_see(u)*prob_lethal(x)
    return np.minimum(p_m, 1)
#%% terminal fitness function
@njit
def terminal_fitness(x):
    t_f = 15.0*x/(x+5.0)
    return t_f
#%% linear interpolation function
@njit
def linear_interpolation(x, F, t):
    floor = x.astype(np.int64)
    delta_c = x-floor
    ceiling = floor + 1
    ceiling[ceiling>x_max] = x_max
    floor[floor<x_min] = x_min
    interpolated_F = (1-delta_c)*F[floor,t] + (delta_c)*F[ceiling,t]
    return interpolated_F
#%% solver
@njit
def solver_jit(P, temp):
    F = np.zeros((len(body_sizes), len(time_steps))) # Expected fitness
    F[:,-1] = terminal_fitness(body_sizes) # expected terminal fitness for every body size
    V = np.zeros((len(foraging_efficiencies), len(body_sizes), len(time_steps))) # Fitness for each foraging effort
    D = np.zeros((len(body_sizes), len(time_steps))) # Decision
    for t in range(t_max-2,-1,-1):
        for x in range(x_min+1, x_max+1): # iterate over every body size except dead
            for i in range(len(foraging_efficiencies)): # iterate over every possible foraging efficiency
                u = foraging_efficiencies[i]
                g_u = intake_dist(u) # calculate the distribution of intakes
                xp = mass_gain(x, u, temp) # calculate the mass gain
                p_m = prob_mort(P, u, x) # probability of mortality
                V[i,x,t] = (1 - p_m)*np.sum((linear_interpolation(xp, F, t+1)*g_u)) # Fitness calculation
            vmax = V[:,x,t].max()
            idx = 0
            for k in range(V.shape[0]): # first index holding the maximum
                if V[k,x,t] == vmax:
                    idx = k
                    break
            #idx = np.argwhere(V[:,x,t]==vmax).min()
            D[x,t] = foraging_efficiencies[idx]
            F[x,t] = vmax
    return D, F
def solver_norm(P, temp):
    F = np.zeros((len(body_sizes), len(time_steps))) # Expected fitness
    F[:,-1] = terminal_fitness(body_sizes) # expected terminal fitness for every body size
    V = np.zeros((len(foraging_efficiencies), len(body_sizes), len(time_steps))) # Fitness for each foraging effort
    D = np.zeros((len(body_sizes), len(time_steps))) # Decision
    for t in range(t_max-1)[::-1]:
        for x in range(x_min+1, x_max+1): # iterate over every body size except dead
            for i in range(len(foraging_efficiencies)): # iterate over every possible foraging efficiency
                u = foraging_efficiencies[i]
                g_u = intake_dist(u) # calculate the distribution of intakes
                xp = mass_gain(x, u, temp) # calculate the mass gain
                p_m = prob_mort(P, u, x) # probability of mortality
                V[i,x,t] = (1 - p_m)*(linear_interpolation(xp, F, t+1)*g_u).sum() # Fitness calculation
            vmax = V[:,x,t].max()
            idx = np.argwhere(V[:,x,t]==vmax).min()
            D[x,t] = foraging_efficiencies[idx]
            F[x,t] = vmax
    return D, F
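Answer 1 (score: 0)
As mentioned above, some of the code is probably falling back to object mode. I just wanted to add that you can use njit instead of jit to disable object mode. That will help diagnose which code is the culprit.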