如何使cython函数接受float或double数组输入?

时间:2018-05-04 02:14:22

标签: python numpy scipy cython blas

假设我有以下(MCVE ...)cython功能

cimport cython

from scipy.linalg.cython_blas cimport dnrm2


cpdef double func(int n, double[:] x):
   cdef int inc = 1
   return dnrm2(&n, &x[0], &inc)

然后,我无法在np.float32数组x上调用它。

如何让func接受double[:]float[:],并拨打dnrm2snrm2?我目前唯一的解决方案是拥有两个函数,这会产生大量重复的代码。

1 个答案:

答案 0 :(得分:2)

您可以使用融合类型。请注意,以下内容无法在我的系统上进行编译,因为ddotsdot显然需要5个参数:

# cython: infer_types=True
cimport cython

from scipy.linalg.cython_blas cimport ddot, sdot

ctypedef fused anyfloat:
   double
   float

cpdef anyfloat func(int n, anyfloat[:] x):
   cdef int inc = 1
   if anyfloat is double:
      return ddot(&n, &x[0], &inc)
   else:
      return sdot(&n, &x[0], &inc)