如何使用ndarray参数和/或输出定义numba类?

时间:2013-10-01 14:33:10

标签: python numpy numba

我试过了:

import numpy as np
import numba

@numba.jit
class foo(object):
    @numba.void(numba.int32)
    def __init__(self, somenum):
        self.somenumarray = np.arange(somenum)

    @numba.jit('f8[:](f8[:])')
    def somemethod1(self, a):
        return self.somenumarray + a

使用@numba.double[:](numba.double[:])方法装饰器会导致错误。

1 个答案:

答案 0 :(得分:1)

可以使用numba.FunctionType

完成此操作
import numpy as np
import numba

bar = numba.FunctionType(return_type=numba.f8[:], args=[numba.f8[:]])

@numba.jit
class foo(object):
    @numba.FunctionType(return_type=numba.void, args=[numba.int32])
    def __init__(self, somenum):
        self.somenumarray = np.arange(somenum)

    @bar
    def somemethod1(self, a):
        return self.somenumarray + a

您可以稍后执行此操作:

quux = foo(3)
quux.somemethod1(np.arange(3))