如何从Numba实现C callable与nquad的高效集成?

时间:2017-08-22 16:54:22

标签: python numpy scipy numba

我需要在python中进行6D的数值积分。因为scipy.integrate.nquad函数很慢,所以我现在试图通过将整数定义为带有Numba的scipy.LowLevelCallable来加快速度。

通过复制给定here的示例,我能够在1D中使用scipy.integrate.quad执行此操作:

import numpy as np
from numba import cfunc
from scipy import integrate

def integrand(t):
    return np.exp(-t) / t**2

nb_integrand = cfunc("float64(float64)")(integrand)

# regular integration
%timeit integrate.quad(integrand, 1, np.inf)

10000次循环,最佳为3:每循环128μs

# integration with compiled function
%timeit integrate.quad(nb_integrand.ctypes, 1, np.inf)

100000个循环,最佳3:每个循环7.08μs

当我想用nquad做这个时,nquad文档说:

  

如果用户希望提高集成性能,那么f可能是a   scipy.LowLevelCallable和其中一个签名:

double func(int n, double *xx)
double func(int n, double *xx, void *user_data)
     

其中n是额外参数的数量,args是数组   附加参数的两倍,xx数组包含   坐标。 user_data是包含在中的数据   scipy.LowLevelCallable。

但是下面的代码给了我一个错误:

from numba import cfunc
import ctypes

def func(n_arg,x):
    xe = x[0]
    xh = x[1]
    return np.sin(2*np.pi*xe)*np.sin(2*np.pi*xh)

nb_func = cfunc("float64(int64,CPointer(float64))")(func)

integrate.nquad(nb_func.ctypes, [[0,1],[0,1]], full_output=True)

错误:quad:第一个参数是带有错误签名的ctypes函数指针

是否可以使用numba编译一个函数,该函数可以直接在代码中与nquad一起使用,而无需在外部文件中定义函数?

非常感谢你!

2 个答案:

答案 0 :(得分:4)

将功能包裹在scipy.LowLevelCallable中使nquad感到满意:

si.nquad(sp.LowLevelCallable(nb_func.ctypes), [[0,1],[0,1]], full_output=True)
# (-2.3958561404687756e-19, 7.002641250699693e-15, {'neval': 1323})

答案 1 :(得分:0)

您传递给double func(int n, double *xx)的功能的签名应为func。您可以为函数import numpy as np import scipy.integrate as si import numba from numba import cfunc from numba.types import intc, CPointer, float64 from scipy import LowLevelCallable def jit_integrand_function(integrand_function): jitted_function = numba.jit(integrand_function, nopython=True) @cfunc(float64(intc, CPointer(float64))) def wrapped(n, xx): return jitted_function(xx[0], xx[1]) return LowLevelCallable(wrapped.ctypes) @jit_integrand_function def func(xe, xh): return np.sin(2*np.pi*xe)*np.sin(2*np.pi*xh) print(si.nquad(func, [[0,1],[0,1]], full_output=True)) >>>(-2.3958561404687756e-19, 7.002641250699693e-15, {'neval': 1323}) 创建一个装饰器,如下所示:

class Article(models.Model):
    headline = models.CharField(max_length=100)
    pub_date = models.DateField()
    reporter = models.ForeignKey(Reporter, on_delete=models.CASCADE)