我正在学习cython以加速numpy。我编写了一个代码来了解如何优化numpy数组计算。 python代码是:
from numpy import *
def set_onsite(n):
a=linspace(0,n,n+1)
onsite=zeros([n+1,n+1],float)
for i in range(0,n+1):
onsite[i,i]=a[i]*a[i]
return onsite
然后,我试图对这段代码进行cython化:
import numpy as np
cimport numpy as np
cimport cython
import cython
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def set_onsite(np.int_t n):
cdef np.ndarray[double,ndim=1,mode='c'] a=np.linspace(0,n,n+1)
cdef np.ndarray[double,ndim=2,mode='c'] onsite=np.empty(n+1,n+1)
cdef np.int_t i
for i in range(0,n+1):
onsite[i,i]=a[i]*a[i]
return onsite
运行setup.py文件后,我得到了.so文件。我运行了代码%timeit myfile.set_onsite(10000)
,但是IPython显示了
TypeError:数据类型不理解
所以有人能告诉我这里发生了什么吗? 我多次检查了我的代码,但我没有弄清楚问题出在哪里。
答案 0 :(得分:4)
问题与cython无关;只是np.empty
期望第一个参数是作为整数的int或元组给出的形状。第二个参数被解释为dtype:
In [19]: np.empty(5,5)
TypeError: data type not understood
而np.empty((5,5))
返回一个空的数组形状(5,5)。
所以改为使用
cdef np.ndarray[double,ndim=2,mode='c'] onsite=np.empty((n+1,n+1))
请注意n+1, n+1
周围的两组括号。或者,使用np.zeros
而不是np.empty
来使Cython函数与Python函数匹配。
PS:在调试Python时,不仅要注意错误消息,还要注意引发异常的行:
File "comp.pyx", line 13, in comp.set_onsite (comp.c:1290)
cdef np.ndarray[double,ndim=2,mode='c'] onsite=np.empty(n+1,n+1)
TypeError: data type not understood