Below is my Cython code; the goal is to run a bootstrap.
def boots(int trial, np.ndarray[double, ndim=2] empirical, np.ndarray[double, ndim=2] expected):
    cdef int length = len(empirical)
    cdef np.ndarray[double, ndim=2] ret = np.empty((trial, 100))
    cdef np.ndarray[long] choices
    cdef np.ndarray[double] m
    cdef np.ndarray[double] n
    cdef long o
    cdef int i
    cdef int j
    for i in range(trial):
        choices = np.random.randint(0, length, length)
        m = np.zeros(100)
        n = np.zeros(100)
        for j in range(length):
            o = choices[j]
            m.__iadd__(empirical[o])
            n.__iadd__(expected[o])
        empirical_boot = m / length
        expected_boot = n / length
        ret[i] = empirical_boot / expected_boot - 1
    ret.sort(axis=0)
    return ret[int(trial * 0.025)].reshape((10,10)), ret[int(trial * 0.975)].reshape((10,10))

# test code
empirical = np.ones((40000, 100))
expected = np.ones((40000, 100))
%prun -l 10 boots(100, empirical,expected)
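With fancy indexing, the pure Python version takes 11 seconds, and no matter how much I tune the Cython version, the runtime stays the same.
np.random.randint(0, 40000, 40000)
takes 1 ms, so 100 calls take 0.1 s.
np.sort(np.ones((40000, 100)))
takes 0.2 s.
So I feel there must be a way to improve boots.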
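As a rough way to reproduce these two measurements outside of %prun, the sketch below times both calls with the standard timeit module. The helper time_call is a hypothetical name introduced here, and the numbers will differ from machine to machine. Note that the array boots actually sorts is ret, with shape (trial, 100), so timing a sort of a 40000 x 100 array overestimates the sort cost inside boots.

```python
import timeit

import numpy as np


def time_call(func, repeat=3):
    """Return the best wall-clock time in seconds over `repeat` single runs."""
    return min(timeit.repeat(func, number=1, repeat=repeat))


length, width, trial = 40000, 100, 100

# One draw of bootstrap indices, as done once per trial inside boots.
t_randint = time_call(lambda: np.random.randint(0, length, length))
# Sort of an array the size of the input (larger than what boots sorts).
t_sort = time_call(lambda: np.sort(np.ones((length, width))))

print(f"randint, one call: {t_randint * 1e3:.2f} ms")
print(f"randint, {trial} calls (estimate): {t_randint * trial:.3f} s")
print(f"sort of a {length} x {width} array: {t_sort:.3f} s")
```

If both figures are small compared with the total runtime, the time is being spent in the inner accumulation loop.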
Answer 0 (score: 3)
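The main issue you are seeing is that Cython only optimizes single-item access for typed arrays. This means that every line in your code that relies on NumPy vectorization still involves creating and interacting with Python objects. The code you had there wasn't faster than the pure Python version because it wasn't really doing any of the computation differently. You have to avoid this by writing out the looping operations explicitly. Here is a modified version of your code that runs significantly faster.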
from numpy cimport ndarray as ar
from numpy cimport int32_t as int32
from numpy import empty
from numpy.random import randint
cimport cython

# Notice the use of these decorators to tell Cython to turn off
# some of the checking it does when accessing arrays.
@cython.boundscheck(False)
@cython.wraparound(False)
def boots(int32 trial, ar[double, ndim=2] empirical, ar[double, ndim=2] expected):
    cdef:
        int32 length = empirical.shape[0], i, j, k
        int32 o
        ar[double, ndim=2] ret = empty((trial, 100))
        ar[int32] choices
        ar[double] m = empty(100), n = empty(100)
    for i in range(trial):
        # Still calling Python on this line.
        # Request 32-bit integers so the result matches the ar[int32] buffer
        # type; randint's default integer type is 64-bit on many platforms.
        choices = randint(0, length, length, dtype='int32')
        # It was faster to compute m and n separately.
        # I suspect that has to do with cache management.
        # Instead of allocating new arrays, I just filled the old ones with the new values.
        o = choices[0]
        for k in range(100):
            m[k] = empirical[o,k]
        for j in range(1, length):
            o = choices[j]
            for k in range(100):
                m[k] += empirical[o,k]
        o = choices[0]
        for k in range(100):
            n[k] = expected[o,k]
        for j in range(1, length):
            o = choices[j]
            for k in range(100):
                n[k] += expected[o,k]
        # Here I simplified some of the math and got rid of temporary arrays:
        # (m / length) / (n / length) - 1 equals m / n - 1.
        for k in range(100):
            ret[i,k] = m[k] / n[k] - 1.
    ret.sort(axis=0)
    return ret[int(trial * 0.025)].reshape((10,10)), ret[int(trial * 0.975)].reshape((10,10))
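To sanity-check a rewritten function like this on small inputs, it can help to have a slow but simple reference. The following is a minimal NumPy sketch of the same bootstrap using fancy indexing; the name boots_reference and the seed parameter are assumptions introduced for illustration, and it uses NumPy's Generator API rather than np.random.randint, so its random draws will not match the Cython version's draw for draw.

```python
import numpy as np


def boots_reference(trial, empirical, expected, seed=None):
    """Slow NumPy reference for the bootstrap; expects arrays with 100 columns."""
    rng = np.random.default_rng(seed)
    length, width = empirical.shape
    ret = np.empty((trial, width))
    for i in range(trial):
        # Resample row indices with replacement.
        choices = rng.integers(0, length, length)
        # Fancy indexing copies the selected rows, then sums each column.
        m = empirical[choices].sum(axis=0)
        n = expected[choices].sum(axis=0)
        # Ratio of sums equals ratio of means, since both share the same length.
        ret[i] = m / n - 1.0
    # Sort each column independently, then pick the 2.5% and 97.5% rows.
    ret.sort(axis=0)
    lower = ret[int(trial * 0.025)].reshape((10, 10))
    upper = ret[int(trial * 0.975)].reshape((10, 10))
    return lower, upper
```

On deterministic input such as the all-ones arrays from the question, both returned bounds should be all zeros, which gives a quick check that does not depend on the random draws.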
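If you want to see which lines of your code involve Python calls, the Cython compiler can generate an HTML file that shows which lines call into Python.
This option is called annotation.
How you use it depends on how you compile your Cython code.
If you are using the IPython notebook, just add the --annotate flag to the Cython cell magic.
You may also benefit from turning on the C compiler's optimization flags.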
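For builds outside the notebook, one possible setup is a setup.py that enables both annotation and compiler optimization. This is only a sketch under stated assumptions: the module name boots and the file boots.pyx are hypothetical, and -O3 is a GCC/Clang flag (MSVC uses /O2 instead). On the command line, cython -a boots.pyx also produces the annotated HTML file.

```python
# setup.py: a sketch of a build script; file and module names are assumptions.
# Build with: python setup.py build_ext --inplace
import numpy
from setuptools import Extension, setup
from Cython.Build import cythonize

ext = Extension(
    "boots",                             # hypothetical module name
    sources=["boots.pyx"],               # hypothetical source file
    include_dirs=[numpy.get_include()],  # needed because the code cimports numpy
    extra_compile_args=["-O3"],          # GCC/Clang flag; MSVC uses /O2 instead
)

# annotate=True writes an HTML report next to the .pyx file; the highlighted
# lines are the ones that call into Python.
setup(ext_modules=cythonize([ext], annotate=True))
```

Whether a higher optimization level helps depends on the compiler and on the flags your Python installation already passes by default, so it is worth timing before and after.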