假设我有以下(MCVE ...)cython功能
cimport cython
from scipy.linalg.cython_blas cimport dnrm2
cpdef double func(int n, double[:] x):
cdef int inc = 1
return dnrm2(&n, &x[0], &inc)
然后,我无法在np.float32
数组x
上调用它。
如何让func
接受double[:]
或float[:]
,并拨打dnrm2
或snrm2
?我目前唯一的解决方案是拥有两个函数,这会产生大量重复的代码。
答案 0 :(得分:2)
您可以使用融合类型。请注意,以下内容无法在我的系统上进行编译,因为ddot
和sdot
显然需要5个参数:
# cython: infer_types=True
cimport cython
from scipy.linalg.cython_blas cimport ddot, sdot
ctypedef fused anyfloat:
double
float
cpdef anyfloat func(int n, anyfloat[:] x):
cdef int inc = 1
if anyfloat is double:
return ddot(&n, &x[0], &inc)
else:
return sdot(&n, &x[0], &inc)