我只是想加速我在numpy中编写的数值算法。关键部分是计算对数似然函数(两个截断的正常CDF之间的差异)。我的功能很慢(每个循环31.9毫秒),我需要每次迭代运行2000次。
我试图使用scipy的“norm.cdf”函数而不是“ecfc”。但它更慢。我也尝试过Numba包中的“@jit”。但它也比原始代码慢。
我想也许我需要使用Cython。但我对C.几乎一无所知C.我试图从Cython for numpy users网页上学习Cython,但这对我来说真的很难。
有人可以帮我在Cython中重写代码吗?或者建议我如何更快地写它?
import numpy as np
from scipy.special import erfc
# The bloody function for calculating the difference between two truncated normal CDFs
def my_loglikelihood2(x,b,c,z):
log_likelihood=np.zeros(np.shape(z)[0])
log_likelihood[x==1]=np.log(0.5*erfc(-(c[1]-np.dot(z[x==1,:],b)) / np.sqrt(2.)) - 0.5*erfc(-(c[0]-np.dot(z[x==1],b)) / np.sqrt(2.)))
log_likelihood[x==2]=np.log(0.5*erfc(-(c[2]-np.dot(z[x==2,:],b)) / np.sqrt(2.)) - 0.5*erfc(-(c[1]-np.dot(z[x==2],b)) / np.sqrt(2.)))
log_likelihood[x==3]=np.log(0.5*erfc(-(c[3]-np.dot(z[x==3,:],b)) / np.sqrt(2.)) - 0.5*erfc(-(c[2]-np.dot(z[x==3],b)) / np.sqrt(2.)))
return log_likelihood
# generate random values
x=np.random.randint(low=1, high=4, size=50000)
b=np.random.normal(0,1,70)
c=np.array([-999,-1,1,999],dtype='f')
z=np.random.multivariate_normal(np.zeros(70), np.eye(70), 50000)
%timeit my_loglikelihood2(x,b,c,z)
# 10 loops, best of 3: 31.9 ms per loop :(
更新1 - 根据建议@jackvdp。它已加入4.5倍。但我仍在寻找更快的代码:
def up_cutoff(x,c):
x[x==1]=c[1]
x[x==2]=c[2]
x[x==3]=c[3]
return x
def low_cutoff(x,c):
x[x==1]=c[0]
x[x==2]=c[1]
x[x==3]=c[2]
def my_loglikelihood2(x,b,low_c,up_c,z):
up_c=up_cutoff(x,c)
low_c=low_cutoff(x,c)
return np.log(0.5*erfc(-(up_c-np.dot(z,b)) / np.sqrt(2.)) - 0.5*erfc(-(low_c-np.dot(z,b)) / np.sqrt(2.)))
%timeit my_loglikelihood2(x,b,low_c,up_c,z)
100 loops, best of 3: 6.58 ms per loop
更新2 - 根据建议@DSM。用zdotb = z.dot(b)替换np.dot(z,b)。改善了1.5ms
def my_loglikelihood2(x,b,low_c,up_c,z):
up_c=up_cutoff(x,c)
low_c=low_cutoff(x,c)
zdotb = z.dot(b)
return np.log(0.5*erfc(-(up_c-zdotb) / np.sqrt(2.)) - 0.5*erfc(-(low_c-zdotb) / np.sqrt(2.)))
%timeit my_loglikelihood2(x,b,low_c,up_c,z)
100 loops, best of 3: 5.02 ms per loop
答案 0 :(得分:1)
如果您的代码由于Python中的循环而变慢,那么将其移植到Cython可以看到很大的改进。但是你的样本只调用了现有的numpy / scipy函数六次。
主要是致电np.log, erfc, np.dot, np.sqrt
。我不确定erfc
但其他人已经使用了编译代码。 Cython不接触那些。
我们可以检查erfc
。
但最好的办法是用更大的数组调用此代码。