在cython中使用numpy masking

时间:2017-07-07 10:08:30

标签: python numpy cython

当我尝试将函数定义为cdef时,我正在尝试将我的一些Python代码转换为Cython并遇到一些问题。

大多数问题归结为掩盖不像Python那样工作。我想知道这是否是cdef的限制(如果我将其保留为def,则工作正常)或者是否有我可以做的事情。

例如这个方法

cdef func(double[:,:,:,:] arg1):
    mask = arg1 > 0
    ...

已经因编译错误而失败:

Error compiling Cython file:
------------------------------------------------------------ ...
    func (double[:,:,:,:] arg1):
        mask = arg1 > 0
                   ^
------------------------------------------------------------

cythonfile.pyx:43:20: Invalid types for '>' (double[:, :, :, :], long)

2 个答案:

答案 0 :(得分:1)

The docs consistently use np.ndarray[...] in function definitions所以我会将你的功能签名更改为

cdef func(np.ndarray[np.float_t, ndim=4] arg1):

此外,您正在将float数组与long integer常量进行比较。将其更改为

mask = arg1 > 0.

floatfloat进行比较。

答案 1 :(得分:1)

double[:,:,:,:]表示法指定参数将被“解释”为Typed Memoryview。这些支持很多操作,但不支持矢量化比较

然而,将内存视图解释为函数内的NumPy数组非常容易:

import numpy as np

cdef func(double[:,:,:,:] arg1):
    arg1arr = np.asarray(arg1)
    mask = arg1arr > 0.

甚至不需要副本,因此在内存视图上执行np.asarray基本上是“免费的”。这允许将内存视图的优点与NumPy阵列上可能的矢量化操作结合起来。

然而,对于矢量化操作,您不需要Cython,您可以在纯python函数中执行所有向量化操作,并且仅使用Cython进行“正常NumPy”函数无法实现的繁重操作。