Numba泡菜错误:无法泡菜类<class-name>:它与类名

时间:2019-02-21 17:09:58

标签: python python-3.x pickle numba

我试图学习numba,因此作为入门练习,我编写了一个简单的轨道求解器:

    import numba as nb
    import numpy as np
    from timeit import default_timer as timer

    spec = [('x0', nb.types.float64),
        ('y0', nb.types.float64),
        ('vx0', nb.types.float64),
        ('vy0', nb.types.float64),
        ('mass', nb.types.float64),
        ('ax', nb.types.float64),
        ('ay', nb.types.float64),
        ('x', nb.types.float64[:]),
        ('y', nb.types.float64[:]),
        ('vx', nb.types.float64[:]),
        ('vy', nb.types.float64[:])]
    @nb.jitclass(spec)
    class CelestialBody():
        def __init__(self, x0, y0, vx0, vy0, mass):
            self.x0 = x0
            self.y0 = y0
            self.vx0 = vx0
            self.vy0 = vy0
            self.mass = mass
            self.ax = 0.0
            self.ay = 0.0

    @nb.jit(nopython=True, cache=True)
    def orbit(bodies, delta_t, nsteps):
        # Set up position arrays
        for j in range(len(bodies)):
            bodies[j].x = np.zeros(nsteps, dtype=np.float64)
            bodies[j].y = np.zeros(nsteps, dtype=np.float64)
            bodies[j].vx = np.zeros(nsteps, dtype=np.float64)
            bodies[j].vy = np.zeros(nsteps, dtype=np.float64)
            bodies[j].x[0] = bodies[j].x0
            bodies[j].y[0] = bodies[j].y0
            bodies[j].vx[0] = bodies[j].vx0
            bodies[j].vy[0] = bodies[j].vy0
        # Loop over every time step (skip 0 since we have x0 and y0)
        for i in range(0, nsteps-1):
            # Get gravitational acceleration for each body at current time
            for j in range(len(bodies)):
                # Reset accelerations
                bodies[j].ax = 0.0
                bodies[j].ay = 0.0
                for k in range(len(bodies)):
                    if j != k:
                        # Get distance between objects
                        dx = bodies[j].x[i] - bodies[k].x[i]
                        dy = bodies[j].y[i] - bodies[k].y[i]
                        d = np.sqrt(dx**2. + dy**2.)
                        # Get acceleration
                        a = -bodies[k].mass / d**2.
                        # Separate into x and y components
                        theta = np.arctan2(dy, dx)
                        bodies[j].ax += a * np.cos(theta)
                        bodies[j].ay += a * np.sin(theta)
            # Update positions
            for j in range(len(bodies)):
                bodies[j].vx[i+1] += bodies[j].vx[i] + bodies[j].ax * delta_t
                bodies[j].vy[i+1] += bodies[j].vy[i] + bodies[j].ay * delta_t
                bodies[j].x[i+1] += bodies[j].x[i] + bodies[j].vx[i] * delta_t    +\
                                0.5 * bodies[j].ax * delta_t**2. 
                bodies[j].y[i+1] += bodies[j].y[i] + bodies[j].vy[i] * delta_t + 0.5 *\
                                bodies[j].ay * delta_t**2
        return bodies



    for i in range(10):
        # Set up celestial bodies
        sun     = CelestialBody(0., 0., 0., 0., 1.)
        earth   = CelestialBody(1., 0., 0., 6.33, 3.00e-6)

        bodies = [sun, earth]

        # Set up time info
        tf = 100. 
        delta_t = tf / 365.
        nsteps = int(tf / delta_t)

        # Orbit
        start = timer()
        bodies = orbit(bodies, delta_t, nsteps)
        end = timer()
        print('Time to run: %f' % (end - start)) 

该代码可以在没有numba的情况下运行。当我添加numba时,我可以同时调试我的类和函数,并且可以很好地运行,并提供了很好的速度。但是,当我尝试使用cache = True缓存jitt'ed函数时,出现KeyError:

File "/usr/local/lib/python3.6/dist-packages/numba/caching.py", line 482, in save
data_name = overloads[key]
KeyError: ((reflected list(instance.jitclass.CelestialBody#2cef1b8<x0:float64,
 y0:float64,vx0:float64,vy0:float64,mass:float64,ax:float64,ay:float64,
x:array(float64, 1d, A),y:array(float64, 1d, A),vx:array(float64, 1d, A),
vy:array(float64, 1d, A)>), float64, int64), ('x86_64-unknown-linux-gnu',
'skylake', '+adx,+aes,+avx,+avx2,-avx512bitalg,-avx512bw,-avx512cd,-avx512dq,
-avx512er,-avx512f,-avx512ifma,-avx512pf,-avx512vbmi,-avx512vbmi2,-avx512vl,
-avx512vnni,-avx512vpopcntdq,+bmi,+bmi2,-cldemote,+clflushopt,-clwb,-clzero,+cmov,
+cx16,+f16c,+fma,-fma4,+fsgsbase,-gfni,+invpcid,-lwp,+lzcnt,+mmx,+movbe,-movdir64b,
-movdiri,-mwaitx,+pclmul,-pconfig,-pku,+popcnt,-prefetchwt1,+prfchw,-ptwrite,
-rdpid,+rdrnd,+rdseed,-rtm,+sahf,+sgx,-sha,-shstk,+sse,+sse2,+sse3,+sse4.1,
+sse4.2,-sse4a,+ssse3,-tbm,-vaes,-vpclmulqdq,-waitpkg,-wbnoinvd,-xop,+xsave,
+xsavec,+xsaveopt,+xsaves'))

我意识到上面的大多数内容都是编译器标志,并且可能是不必要的,但是我不确定,因此我想将其包括在内。

还有一个泡菜错误:

_pickle.PicklingError: Can't pickle <class '__main__.CelestialBody'>: it's not the same object as __main__.CelestialBody

我尝试查看this question,但据我所知没有导入错误,而且我也没有弄错我要导入的任何模块。我也没有在jupyter笔记本中运行,而只是在终端上运行。我的猜测是,它与类“ signature”在编译之前和之后都有关系,而pickle对此更改感到困惑。不使用类时,我可以使缓存起作用。

我正在使用Python版本3.6.7,numpy版本1.15.4和numba版本0.42.1

所以,我的问题是什么导致此腌制错误阻止了缓存?谢谢!

0 个答案:

没有答案