这是一个关于从Cython融合类型转换为C ++类型的一般性问题,我将用一个最小的例子来描述。考虑表面的C ++函数模板:
template <typename T>
void scale_impl(const T * x, T * y, const T a, const size_t N) {
for (size_t n = 0; n < N; ++n) {
y[n] = a*x[n];
}
}
我希望能够在任何类型和形状的任何numpy ndarray
上调用此函数。使用Cython,我们首先声明函数模板:
cdef extern:
void scale_impl[T](const T * x, T * y, const T a, const size_t N)
然后声明我们希望操作的有效标量类型:
ctypedef fused Scalar:
float
double
...
最后实现实际的Cython垫片:
def scale(ndarray[Scalar] x, Scalar a):
"""Scale an array x by the value a"""
cdef ndarray[Scalar] y = np.empty_like(x)
scale_impl(<Scalar *>x.data, <Scalar *>y.data, a, x.size)
return y
这不起作用有两个原因:
x
只能是1维,而不是任意(或至少很多)维度<Scalar *>
会引发错误,因为Scalar
实际上是一个Python对象显然可以明确地推断出这些专业化:
if Scalar is float:
scale_impl(<float *>x.data, <float *>y.data, a, x.size)
if Scalar is double:
scale_impl(<double *>x.data, <double *>y.data, a, x.size)
if Scalar is ...
但是这会产生一个组合数量的代码路径,我必须手工编写用于娱乐多个融合类型的函数,并创建了这种情况(我假设)引入融合类型以避免。
有没有办法将任意维度(在合理范围内)数组传递给Cython函数并让它推导出标量数据的指针类型?或者,采用这种功能最合理的折衷方案是什么?
答案 0 :(得分:1)
(另请参阅Using Cython to wrap a c++ template to accept any numpy array中给出的答案,这是一个非常类似的问题。)
使用表单&x[0]
而不是尝试投射x.data
解决了选择正确模板专精的问题。 2D阵列的问题有点复杂,因为阵列不能保证连续或有序。
我创建了一个在1D数组上执行实际工作的函数,并将其包装在一个根据需要展平的简单函数中:
def _scale_impl(Scalar[::1] x, Scalar a):
# the "::1" syntax ensures the array is actually continuous
cdef np.ndarray[Scalar,ndim=1] y = np.empty_like(x)
cdef size_t N = x.shape[0] # this seems to be necessary to avoid throwing off Cython's template deduction
scale_impl(&x[0],&y[0],a,N)
return y
def scale(x, a):
"""Scale an array x by the value a"""
y = _scale_impl(np.ravel(x),a)
return y.reshape(x.shape) # reshape needs to be kept out of Cython