我有一个函数来计算dirichlet分布的条件(第k个alpha)对数似然。我用Cython编写并编译,但我的代码调用大约12M次,它似乎是瓶颈,所以我希望加快它。
cimport numpy as np
import numpy as np
import math
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
def logFullConAlphaK(np.ndarray p,np.ndarray alpha, np.int k):
assert p.dtype == np.float64 and alpha.dtype == np.float64
cdef double t1=sum(np.log(p))
cdef DTYPE_t y=((alpha[k-1]-1)*t1)-np.log(alpha[k-1])+(p.shape[0]*
(math.lgamma(sum(alpha))- math.lgamma(alpha[k-1])))
return y
我将Cython编译成我在代码中使用的.pyd文件。有关如何加快速度的想法吗?
由于
答案 0 :(得分:3)
1)通过声明输入数组和p.shape[0]
的数据类型和维度:
def logFullConAlphaK(np.ndarray[DTYPE_t, ndim=1] p,
np.ndarray[DTYPE_t, ndim=1] alpha, int k):
...
cdef int tmp
tmp = p.shape[0]
2)通过使用C函数而不是模块math
中的Python函数:
cdef extern from "math.h":
double log(double x) nogil
3)使用NumPy的np.ndarray.sum()
方法
4)使用Cython指令避免一些开销
共:
#cython: wraparound=False
#cython: boundscheck=False
#cython: cdivision=True
#cython: nonecheck=False
import math
cimport numpy as np
import numpy as np
cdef extern from "math.h":
double log(double x) nogil
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
def logFullConAlphaK(np.ndarray[DTYPE_t, ndim=1] p,
np.ndarray[DTYPE_t, ndim=1] alpha, int k):
assert p.dtype == np.float64 and alpha.dtype == np.float64
cdef double t1
cdef int tmp
t1 = np.log(p).sum()
tmp = p.shape[0]
cdef DTYPE_t y=((alpha[k-1]-1)*t1)-log(alpha[k-1])+(tmp*
(math.lgamma(alpha.sum()) - math.lgamma(alpha[k-1])))
return y
OP的原始解决方案@ cel的解决方案与我的一些性能比较:
In [2]: timeit solOP(a, b, 10)
1000 loops, best of 3: 273 µs per loop
In [3]: timeit solcel(a, b, 10)
10000 loops, best of 3: 30.5 µs per loop
In [4]: timeit solS(a, b, 10)
100000 loops, best of 3: 15.8 µs per loop
答案 1 :(得分:2)
拿这个(可能是完全不现实的)样本数据:
TypeError: undefined is not an object (evaluating 'b.call')
我得到以下时间:
n = 1000000
p = np.random.rand(n)
alpha = np.random.rand(n)
k = 12
- > %timeit logFullConAlphaK(p, alpha, k)
1 loops, best of 3: 174 ms per loop
- > %timeit logFullConAlphaK_opt(p, alpha, k)
此版本已经为您提供了一个数量级的速度。请注意,几乎所有加速都来自于使用内置100 loops, best of 3: 13.3 ms per loop
上的np.sum
。所有其他更改只是为了更清晰的代码,它们对速度没有影响。
sum