Cython:声明类似列表的函数参数

时间:2018-07-10 15:15:32

标签: python cython

我正在尝试创建一个简单的cython模块,并遇到以下问题。我想创建一个像这样的函数:

cdef float calc(float[:] a1, float[:] a2):
    cdef float res = 0
    cdef int l = len(a2)
    cdef float item_a2
    cdef float item_a1

    for idx in range(l):
        if a2[idx] > 0:
            item_a2 = a2[idx]
            item_a1 = a1[idx]
            res += item_a2 * item_a1

    return res

函数执行时,a1和a2参数是python列表。因此我得到了错误:

  

TypeError:需要一个类似字节的对象,而不是“列表”

我只需要进行这样的计算,仅此而已。但是,如果我需要使用C来最大化速度,该如何定义输入参数float[:] a1float[:] a2? 也许有必要手动将列表转换为数组吗?

P.S。如果您还可以向我解释是否有必要显式声明cdef float item_a2来执行乘法(就性能而言)还是等同于result += a2[idx] * a1[idx]

,也将不胜感激

2 个答案:

答案 0 :(得分:1)

Cython答案

一种执行此操作的方法(如果您愿意使用numpy):

import numpy as np
cimport numpy as np

ctypedef np.npy_float FLOAT
ctypedef np.npy_intp INTP

cdef FLOAT calc(np.ndarray[FLOAT, ndim=1, mode='c'] a1, 
                np.ndarray[FLOAT, ndim=1, mode='c'] a2):
    cdef FLOAT res = 0
    cdef INTP l = a2.shape[0]
    cdef FLOAT item_a2
    cdef FLOAT item_a1

    for idx in range(l):
        if a2[idx] > 0:
            item_a2 = a2[idx]
            item_a1 = a1[idx]
            res += item_a2 * item_a1

    return res

这将要求您的数组使用np.float32 dtype。如果需要np.float64,则可以将FLOAT重新定义为np.float64_t

一个不请自来的建议... l对于变量来说是个坏名字,因为它看起来像一个数字。考虑将其重命名为length或类似名称。

带有Numpy的纯python

最后,似乎您要计算两个向量之间的点积,其中一个数组中的元素为正。您可以在这里非常有效地使用Numpy来获得相同的结果。

>>> import numpy as np
>>> a1 = np.array([0, 1, 2, 3, 4, 5, 6])
>>> a2 = np.array([1, 2, 0, 3, -1])
>>> a1[:a2.shape[0]].dot(np.maximum(a2, 0))
11

请注意,我添加了a1切片,因为您没有在Cython函数中检查长度相等性,而是使用了a2的长度。因此,我认为长度可能会有所不同。

答案 1 :(得分:1)

cdef float calc(float[:] a1, float[:] a2):

a1a2可以是supports the buffer protocol并且具有float类型的任何对象。最常见的示例是numpy数组或standard library array module。他们将不接受Python列表,因为Python列表不是有效填充到内存中的单一同质C类型,而是Python对象的集合。

要从Python列表中创建合适的对象,您可以执行以下任一操作:

numpy.array([1.0,2.0],dtype=numpy.float32)
array.array('f',[1.0,2.0])

(您可能要考虑使用double / float64而不是float来提高精度,但这是您的选择)

如果您不想创建像这样的数组对象,那么Cython将无济于事,因为普通列表无法提供太多的速度。

另一种建议的np.ndarray[FLOAT, ndim=1] a1语法回答了您已经在使用的memoryview语法的过时版本。使用它没有好处(有一些小缺点)。


result += a2[idx] * a1[idx]

很好-Cython知道a1a2的类型,因此无需创建临时中间变量。您可以获取带有cython -a filename.pyx的html高亮文件进行检查,这将有助于指示未加速零件的位置。