How to use a numba cfunc as the integrand for scipy.integrate.dblquad

Posted: 2018-09-27 04:35:33

Tags: python scipy numba integral

I have read "How to pass additional parameters to numba cfunc passed as LowLevelCallable to scipy.integrate.quad".

To be honest, I am still quite confused.

Now I want to wrap the dblquad call mentioned in https://gist.github.com/Vindaar/aab2926425400fc57274b521e80398dd

Here is the code:

import time

import numba
import numpy as np
from scipy.integrate import dblquad

a = 1.0        # some constant
delta_y = 0.5  # some shift
sigma = 2.0    # and our Gaussian's sigma

def get_integrand(*args):
    delta_y, sigma = args
    def integrand2(r, theta):
        # integrand of the function that will be integrated in 2D
        r_prime_2 = r**2 + delta_y**2 - 2*r*delta_y*np.cos(3.0/2.0*np.pi - theta)
        value = 1/(2*np.pi*sigma**2)*np.exp(-r_prime_2/(2*sigma**2))*r
        return value

    return integrand2

# delta_y and sigma are frozen into the closure when numba compiles it
nb_integrand = numba.cfunc("float64(float64, float64)")(get_integrand(delta_y, sigma)).ctypes

t0 = time.perf_counter()  # time.clock() was removed in Python 3.8
for i in range(1000):
    res = dblquad(
        #~ scipy.LowLevelCallable(nb_integrand),  # explicit wrapper
        nb_integrand,          # ctypes pointer to the numba cfunc
        0.0,                   # outer integral lower
        2.0*np.pi,             # outer integral upper
        lambda theta: 0.0,     # inner integral lower
        lambda theta: a,       # inner integral upper
        args=(delta_y, sigma)  # additional arguments
    )

t1 = time.perf_counter()
print(res)

But I get:

            Traceback (most recent call last):
              File "test.py", line 59, in <module>
                args=(delta_y, sigma) # additional arguments
              File "E:\prg\py\Anaconda3_64\lib\site-packages\scipy\integrate\quadpack.py", line 581, in dblquad
                opts={"epsabs": epsabs, "epsrel": epsrel})
              File "E:\prg\py\Anaconda3_64\lib\site-packages\scipy\integrate\quadpack.py", line 805, in nquad
                return _NQuad(func, ranges, opts, full_output).integrate(*args)
              File "E:\prg\py\Anaconda3_64\lib\site-packages\scipy\integrate\quadpack.py", line 860, in integrate
                **opt)
              File "E:\prg\py\Anaconda3_64\lib\site-packages\scipy\integrate\quadpack.py", line 341, in quad
                points)
              File "E:\prg\py\Anaconda3_64\lib\site-packages\scipy\integrate\quadpack.py", line 448, in _quad
                return _quadpack._qagse(func,a,b,args,full_output,epsabs,epsrel,limit)
              File "E:\prg\py\Anaconda3_64\lib\site-packages\scipy\integrate\quadpack.py", line 860, in integrate
                **opt)
              File "E:\prg\py\Anaconda3_64\lib\site-packages\scipy\integrate\quadpack.py", line 341, in quad
                points)
              File "E:\prg\py\Anaconda3_64\lib\site-packages\scipy\integrate\quadpack.py", line 448, in _quad
                return _quadpack._qagse(func,a,b,args,full_output,epsabs,epsrel,limit)
            ValueError: Invalid scipy.LowLevelCallable signature "double (double, double)". Expected one of: ['double (double)', 'double (int, double)', 'double (long, double)']

Any hints? Thanks.

0 answers