我试图学习numba,因此作为入门练习,我编写了一个简单的轨道求解器:
import numba as nb
import numpy as np
from timeit import default_timer as timer
spec = [('x0', nb.types.float64),
('y0', nb.types.float64),
('vx0', nb.types.float64),
('vy0', nb.types.float64),
('mass', nb.types.float64),
('ax', nb.types.float64),
('ay', nb.types.float64),
('x', nb.types.float64[:]),
('y', nb.types.float64[:]),
('vx', nb.types.float64[:]),
('vy', nb.types.float64[:])]
@nb.jitclass(spec)
class CelestialBody():
def __init__(self, x0, y0, vx0, vy0, mass):
self.x0 = x0
self.y0 = y0
self.vx0 = vx0
self.vy0 = vy0
self.mass = mass
self.ax = 0.0
self.ay = 0.0
@nb.jit(nopython=True, cache=True)
def orbit(bodies, delta_t, nsteps):
# Set up position arrays
for j in range(len(bodies)):
bodies[j].x = np.zeros(nsteps, dtype=np.float64)
bodies[j].y = np.zeros(nsteps, dtype=np.float64)
bodies[j].vx = np.zeros(nsteps, dtype=np.float64)
bodies[j].vy = np.zeros(nsteps, dtype=np.float64)
bodies[j].x[0] = bodies[j].x0
bodies[j].y[0] = bodies[j].y0
bodies[j].vx[0] = bodies[j].vx0
bodies[j].vy[0] = bodies[j].vy0
# Loop over every time step (skip 0 since we have x0 and y0)
for i in range(0, nsteps-1):
# Get gravitational acceleration for each body at current time
for j in range(len(bodies)):
# Reset accelerations
bodies[j].ax = 0.0
bodies[j].ay = 0.0
for k in range(len(bodies)):
if j != k:
# Get distance between objects
dx = bodies[j].x[i] - bodies[k].x[i]
dy = bodies[j].y[i] - bodies[k].y[i]
d = np.sqrt(dx**2. + dy**2.)
# Get acceleration
a = -bodies[k].mass / d**2.
# Separate into x and y components
theta = np.arctan2(dy, dx)
bodies[j].ax += a * np.cos(theta)
bodies[j].ay += a * np.sin(theta)
# Update positions
for j in range(len(bodies)):
bodies[j].vx[i+1] += bodies[j].vx[i] + bodies[j].ax * delta_t
bodies[j].vy[i+1] += bodies[j].vy[i] + bodies[j].ay * delta_t
bodies[j].x[i+1] += bodies[j].x[i] + bodies[j].vx[i] * delta_t +\
0.5 * bodies[j].ax * delta_t**2.
bodies[j].y[i+1] += bodies[j].y[i] + bodies[j].vy[i] * delta_t + 0.5 *\
bodies[j].ay * delta_t**2
return bodies
for i in range(10):
# Set up celestial bodies
sun = CelestialBody(0., 0., 0., 0., 1.)
earth = CelestialBody(1., 0., 0., 6.33, 3.00e-6)
bodies = [sun, earth]
# Set up time info
tf = 100.
delta_t = tf / 365.
nsteps = int(tf / delta_t)
# Orbit
start = timer()
bodies = orbit(bodies, delta_t, nsteps)
end = timer()
print('Time to run: %f' % (end - start))
该代码可以在没有numba的情况下运行。当我添加numba时,我可以同时调试我的类和函数,并且可以很好地运行,并提供了很好的速度。但是,当我尝试使用cache = True缓存jitt'ed函数时,出现KeyError:
File "/usr/local/lib/python3.6/dist-packages/numba/caching.py", line 482, in save
data_name = overloads[key]
KeyError: ((reflected list(instance.jitclass.CelestialBody#2cef1b8<x0:float64,
y0:float64,vx0:float64,vy0:float64,mass:float64,ax:float64,ay:float64,
x:array(float64, 1d, A),y:array(float64, 1d, A),vx:array(float64, 1d, A),
vy:array(float64, 1d, A)>), float64, int64), ('x86_64-unknown-linux-gnu',
'skylake', '+adx,+aes,+avx,+avx2,-avx512bitalg,-avx512bw,-avx512cd,-avx512dq,
-avx512er,-avx512f,-avx512ifma,-avx512pf,-avx512vbmi,-avx512vbmi2,-avx512vl,
-avx512vnni,-avx512vpopcntdq,+bmi,+bmi2,-cldemote,+clflushopt,-clwb,-clzero,+cmov,
+cx16,+f16c,+fma,-fma4,+fsgsbase,-gfni,+invpcid,-lwp,+lzcnt,+mmx,+movbe,-movdir64b,
-movdiri,-mwaitx,+pclmul,-pconfig,-pku,+popcnt,-prefetchwt1,+prfchw,-ptwrite,
-rdpid,+rdrnd,+rdseed,-rtm,+sahf,+sgx,-sha,-shstk,+sse,+sse2,+sse3,+sse4.1,
+sse4.2,-sse4a,+ssse3,-tbm,-vaes,-vpclmulqdq,-waitpkg,-wbnoinvd,-xop,+xsave,
+xsavec,+xsaveopt,+xsaves'))
我意识到上面的大多数内容都是编译器标志,并且可能是不必要的,但是我不确定,因此我想将其包括在内。
还有一个泡菜错误:
_pickle.PicklingError: Can't pickle <class '__main__.CelestialBody'>: it's not the same object as __main__.CelestialBody
我尝试查看this question,但据我所知没有导入错误,而且我也没有弄错我要导入的任何模块。我也没有在jupyter笔记本中运行,而只是在终端上运行。我的猜测是,它与类“ signature”在编译之前和之后都有关系,而pickle对此更改感到困惑。不使用类时,我可以使缓存起作用。
我正在使用Python版本3.6.7,numpy版本1.15.4和numba版本0.42.1
所以,我的问题是什么导致此腌制错误阻止了缓存?谢谢!