如何在类的成员函数上使用numba?

时间:2017-01-20 17:19:11

标签: python oop numba

我使用的是稳定版的Numba 0.30.1。

我可以这样做:

import numba as nb
@nb.jit("void(f8[:])",nopython=True)                             
def complicated(x):                                  
    for a in x:
        b = a**2.+a**3.

作为测试案例,加速是巨大的。但是如果我需要加速课堂内的功能,我不知道如何继续。

import numba as nb
def myClass(object):
    def __init__(self):
        self.k = 1
    #@nb.jit(???,nopython=True)                             
    def complicated(self,x):                                  
        for a in x:
            b = a**2.+a**3.+self.k

我对self对象使用什么numba类型?我需要在类中使用此函数,因为它需要访问成员变量。

3 个答案:

答案 0 :(得分:8)

我的处境非常相似,我找到了在类中使用Numba-JITed函数的方法。

诀窍是使用静态方法,因为这种方法不会被称为将对象实例添加到参数列表之前。无法访问self的不利之处是您不能使用在方法外部定义的变量。因此,您必须将它们从有权访问self的调用方法传递给静态方法。就我而言,我不需要定义包装方法。我只需要将要JIT编译的方法分成两个方法即可。

在您的示例中,解决方案是:

from numba import jit

class MyClass:
    def __init__(self):
        self.k = 1

    def calculation(self):
        k = self.k
        return self.complicated([1,2,3],k)

    @staticmethod
    @jit(nopython=True)                             
    def complicated(x,k):                                  
        for a in x:
            b = a**2 .+ a**3 .+ k

答案 1 :(得分:5)

您有几种选择:

使用jitclasshttp://numba.pydata.org/numba-doc/0.30.1/user/jitclass.html)“numba-ize”整个事情。

或者使成员函数成为包装器并通过以下方式传递成员变量:

import numba as nb

@nb.jit
def _complicated(x, k):
    for a in x:
        b = a**2.+a**3.+k

def myClass(object):
    def __init__(self):
        self.k = 1

    def complicated(self,x):                                  
        _complicated(x, self.k)

答案 2 :(得分:0)

documentation中,您可以:

使用@jitclass编译Python类

Numba支持通过Notes screen装饰器为类生成代码。可以使用此装饰器将类标记为优化,并指定每个字段的类型。我们将生成的类对象称为jitclass。 jitclass的所有方法都编译为nopython函数。 jitclass实例的数据作为C兼容结构分配在堆上,因此任何编译函数都可以绕过解释器直接访问基础数据。

这是jitclass的基础示例:

numba.jitclass()

在上面的示例中,提供了import numpy as np from numba import int32, float32 # import the types from numba.experimental import jitclass spec = [ ('value', int32), # a simple scalar field ('array', float32[:]), # an array field ] @jitclass(spec) class Bag(object): def __init__(self, value): self.value = value self.array = np.zeros(value, dtype=np.float32) @property def size(self): return self.array.size def increment(self, val): for i in range(self.size): self.array[i] += val return self.array @staticmethod def add(x, y): return x + y n = 21 mybag = Bag(n) 作为2元组的列表。元组包含字段名称和字段的Numba类型。

或者,您可以使用字典(最好将spec用于将字段名映射为类型。

该类的定义至少需要一个OrderedDict方法来初始化每个定义的字段。未初始化的字段包含垃圾数据。可以定义方法和属性(仅限getter和setter)。它们将被自动编译。

请注意,当前这是jitclass支持的早期版本,并且是试验性的。尚未公开或实现所有编译功能。