如何在cython中的函数参数中键入函数

时间:2018-09-10 06:55:12

标签: python cython static-typing

我制作了一个用于在python中进行优化的函数(我们将其称为optimizer)。它要求将函数作为函数参数之一进行优化(我们将其称为objective)。 objective是一个函数,它接受一维np.ndarray并返回一个float数字(与C ++中的double相同)。

我已经阅读了此post,但是我不确定它是否确实与我的问题以及使用ctypedef int (*f_type)(int, str)时的问题相同,但是在编译过程中出现错误Cannot convert 'f_type' to Python object 。它仅适用于C函数吗?如何键入python函数?

编辑:我的代码如下:

cpdef optimizer(objective, int num_particle, int dim,
             np.ndarray[double, ndim=1] lower_bound,
             np.ndarray[double, ndim=1] upper_bound):

    cdef double min_value
    cdef np.ndarray[double, ndim=2] positions = np.empty((num_particle,dim), dtype=np.double)
    cdef np.ndarray[double, ndim=1] fitness = np.empty(num_particle, dtype=np.double)
    cdef int i, j

    # do lots of stuff not shown here
    # involve the following code:
    for i in range(num_particle):
        fitness[i] = objective(positions[i])

    return min_value

我想知道是否可以键入objective来使代码运行更快。

1 个答案:

答案 0 :(得分:3)

我收到错误消息

  

Cannot convert Python object argument to type 'f_type'

我认为

比您声称得到的要有意义得多-您正在尝试将Python对象传递给该函数。请确保您报告的错误消息是您的代码实际生成的错误消息。您对objective采用的类型的描述也与您显示的代码不匹配。


但是,通常来说:不,您不能给目标函数一个类型说明符来加速它。通用的Python可调用程序比C函数指针携带更多的信息(例如引用计数,任何闭包捕获变量的详细信息等)。

一种可能的替代方法是使用适当的cdef class函数从cdef继承,这样至少可以在特定情况下获得适当的性能:

# an abstract function pointer class
cdef class FPtr:
    cdef double function(self,double[:] x) except? 0.0:
        # I'm assuming you might want to pass exceptions back to Python - use 0.0 to indicate that there might have been an error
        raise NotImplementedError()

# an example class that inherits from the abstract pointer type    
cdef class SumSq(FPtr):
    cdef double function(self,double[:] x) except? 0.0:
        cdef double sum=0.0
        for i in range(x.shape[0]):
            sum += x[i]**2
        return sum

# an example class that just wraps a Python callable
# this will be no faster, but makes the code generically usable
cdef class PyFPtr(FPtr):
    cdef object f
    def __init__(self,f):
        self.f = f

    cdef double function(self,double[:] x) except? 0.0:
        return self.f(x) # will raise an exception if the types don't match

def example_function(FPtr my_callable):
    import numpy as np
    return my_callable.function(np.ones((10,)))

使用此example_function(SumSq())可以正常工作(并具有Cython速度); example_function(PyFPtr(lambda x: x[0]))可以正常工作(可调用项中没有Cython速度); example_function(PyFPtr(lambda x: "hello"))出现预期的类型错误。