定义cython类中函数的参数和cython中的快速积分计算

时间:2014-06-06 15:37:39

标签: python arrays numpy scipy cython

我是cython的新手并尝试将python类转换为 cython 。我不知道如何在实例Da 中定义参数z,因为它可以处理 numpy.array 或只是一个{ {1}}数字。

float

我也想知道我是否正确地将所有参数转换为cython参数?我想改变我原来的python代码以提高计算速度。我认为我的代码中的一个瓶颈应该是cdef class Cosmology(object): cdef double omega_m, omega_lam, omega_c def __init__(self,double omega_m=0.3,double omega_lam=0.7): self.omega_m = omega_m self.omega_lam = omega_lam self.omega_c = (1. - omega_m - omega_lam) cpdef double a(self, double z): cdef double a return 1./(1+z) cpdef double E(self, double a): cdef double E return (self.omega_m*a**(-3) + self.omega_c*a**(-2) + self.omega_lam)**0.5 cpdef double __angKernel(self, double x): cdef __angKernel: """Integration kernel""" return self.E(x**-1)**-1 cpdef double Da(self, double z, double z_ref=0): cdef double Da if isinstance(z, np.ndarray): da = np.zeros_like(z) for i in range(len(da)): da[i] = self.Da(z[i], z_ref) return da else: if z < 0: raise ValueError("Redshift z must not be negative") if z < z_ref: raise ValueError("Redshift z must not be smaller than the reference redshift") d = integrate.quad(self.__angKernel, z_ref+1, z+1,epsrel=1.e-6, epsabs=1.e-12) rk = (abs(self.omega_c))**0.5 if (rk*d[0] > 0.01): if self.omega_c > 0: d[0] = sinh(rk*d[0])/rk if self.omega_c < 0: d[0] = sin(rk*d[0])/rk return d[0]/(1+z) integrate.quad中是否有替代此功能有助于加快代码的性能?

cython

如果我想将数组传递给cdef class halo_positions(object): cdef double x = None cdef double y = None def __init__(self,numpy.ndarray[double, ndim=1] positions): self.x = positions[0] self.y = positions[1] 实例,这是一种正确的方法吗?

1 个答案:

答案 0 :(得分:1)

如果你的类定义为cdef,那么它只能在Cython中使用(不能在Python中访问),因此对类方法使用cpdefdef是不必要的,效率也不高。您可以将它们全部转换为cdef

当您告知zdouble时,它只接受double。如果您希望此参数具有两种不同的类型,则应保持其类型未声明,但这会在zndarray时直接影响循环性能。

或者你可以使用double *并传递它的大小,当大小为1时它是一个双精度,当大小为>1一个数组时。功能是:

cdef double Da(self, int size, double *z, double z_ref=0):
    if size>1:
        da = np.zeros(size)
        for i in range(size):
            da[i] = self.Da(1, &z[i], z_ref)
        return da
    else:
        if z[0] < 0:
            raise ValueError("Redshift z must not be negative")
        if z[0] < z_ref:
            raise ValueError("Redshift z must not be smaller than the reference redshift")

        d = integrate.quad(self.__angKernel, z_ref+1, z[0]+1,
                           epsrel=1.e-6, epsabs=1.e-12)
        rk = (abs(self.omega_c))**0.5
        if (rk*d[0] > 0.01):
            if self.omega_c > 0:
                d[0] = sinh(rk*d[0])/rk
            if self.omega_c < 0:
                d[0] = sin(rk*d[0])/rk
        return d[0]/(1+z[0])