当包含自身的jitclass类时,如何使python类jitclass兼容?

时间:2019-12-06 14:31:25

标签: python class jit numba

我正在尝试创建一个类,该类可以是jitclass的一部分,但具有一些本身就是jitclass对象的属性。

例如,如果我有两个带有装饰器@jitclass的类,我想将它们实例化为第三类(combined)。

import numpy as np
from numba import jitclass
from numba import boolean, int32, float64,uint8

spec = [
    ('type' ,int32),
    ('val' ,float64[:]),
    ('result',float64)]

@jitclass(spec)
class First:
    def __init__(self):
        self.type = 1
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)

@jitclass(spec)
class Second:
    def __init__(self):
        self.type = 2
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)



@jitclass(spec)
class Combined:
    def __init__(self):
        self.List = []
        for i in range(10):
            self.List.append(First())
            self.List.append(Second())

    def sum(self):
        for i, c in enumerate(self.List):
            c.sum()
    def getresult(self):
        result = []
        for i, c in enumerate(self.List):
            result.append(c.result)
        return result


C = Combined()
C.sum()
result = C.getresult()
print(result)

在该示例中,我收到一个错误,因为numba无法确定self.List的类型,它是两个jitclass的组合。

如何使Combined类与jitclass兼容?

更新

它尝试了我在其他地方找到的东西:

import numpy as np
from numba import jitclass, deferred_type
from numba import boolean, int32, float64,uint8
from numba.typed import List

spec = [
    ('type' ,int32),
    ('val' ,float64[:]),
    ('result',float64)]

@jitclass(spec)
class First:
    def __init__(self):
        self.type = 1
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)



spec1 = [('ListA',  List(First.class_type.instance_type, reflected=True))]

@jitclass(spec1)
class Combined:
    def __init__(self):
        self.ListA = [First(),First()] 

    def sum(self):
        for i, c in enumerate(self.ListA):
            c.sum()
    def getresult(self):
        result = []
        for i, c in enumerate(self.ListA):
            result.append(c.result)
        return result


C = Combined()
C.sum()
result = C.getresult()
print(result)

但是我得到这个错误

List(First.class_type.instance_type)
TypeError: __init__() takes 1 positional argument but 2 were given

1 个答案:

答案 0 :(得分:2)

TL; DR:

  • 您可以引用jitclass中的其他jitclass,即使您有这些列表。您只需要更正命名空间numba.typed-> numba.types
  • 目前(从numba 0.46开始)无法在jitclass es或no-python numba.jit函数中包含异构列表。因此,您不能将FirstSecond的两个实例都附加到同一列表中。

解决numba.typed.List异常

您的更新几乎是正确的。您需要使用numba.types.List而不是numba.typed.List。区别有些细微,但是numba.types包含签名的类型,而numba.typed名称空间包含可以实例化并在代码中使用的类。

因此,如果您使用它,它将起作用:

spec1 = [('ListA',  nb.types.List(First.class_type.instance_type, reflected=True))]

更改该代码:

import numpy as np
import numba as nb

spec = [
    ('type', nb.int32),
    ('val', nb.float64[:]),
    ('result', nb.float64)
]

@nb.jitclass(spec)
class First:
    def __init__(self):
        self.type = 1
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)

spec1 = [('ListA',  nb.types.List(First.class_type.instance_type, reflected=True))]

@nb.jitclass(spec1)
class Combined:
    def __init__(self):
        self.ListA = [First(), First()] 
    def sum(self):
        for i, c in enumerate(self.ListA):
            c.sum()
    def getresult(self):
        result = []
        for i, c in enumerate(self.ListA):
            result.append(c.result)
        return result

C = Combined()
C.sum()
result = C.getresult()
print(result)

产生输出:[100.0, 100.0]

Intermezzo:在这里使用jitclass是否有意义?

不过要记住的是,普通的Python类可能比jitclass方法快(或快):

import numpy as np
import numba as nb

class First:
    def __init__(self):
        self.type = 1
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)

class Combined:
    def __init__(self):
        self.ListA = [First(), First()] 
    def sum(self):
        for i, c in enumerate(self.ListA):
            c.sum()
    def getresult(self):
        result = []
        for i, c in enumerate(self.ListA):
            result.append(c.result)
        return result

C = Combined()
C.sum()
C.getresult()

这只是出于好奇而已。但是对于生产而言,我将从纯Python + NumPy开始,仅在速度太慢时才应用numba,然后仅在瓶颈所在的部分上应用numba,并且仅当numba擅长优化这些功能时(numba目前是专用工具,而不是通用工具)。

带有numba的异构(混合类型)列表吗?

在无python(无对象)模式下使用numba时,您需要同类列表。据我所知,numba 0.46不支持在jitclasses或nopython-jit方法中包含不同类型对象的列表。这意味着您不能有一个包含FirstSecond实例的列表。

所以这行不通:

self.List.append(First())
self.List.append(Second())

来自numba docs

  

支持从JIT编译的函数创建和返回列表,以及所有方法和操作。 列表必须严格相同:Numba将拒绝任何包含不同类型对象的列表,即使这些类型兼容也是如此 [...]