我正在优化一些我拥有的代码,这些代码主要包含在单个python类中。它很少处理python对象,所以我认为使用Numba会是一个很好的选择,但是在创建对象的过程中我需要大量参数,而且我认为我不太了解Numba相对较新的dict支持(documentation here)。我拥有的参数都是单个浮点数或整数,并且将它们传递到对象中,进行存储并在整个代码运行过程中使用,例如:
import numpy as np
from numba import jitclass, float64
spec = [
('p', dict),
('shape', tuple), # the shape of the array
('array', float64[:,:]), # an array field
]
params_default = {
par_1 = 1,
par_2 = 0.5
}
@jitclass(spec)
class myObj:
def __init__(self,params = params_default,shape = (100,100)):
self.p = params
self.shape = shape
self.array = self.p['par_2']*np.ones(shape)
def inc_arr(self):
self.array += self.p['par_1']*np.ones(shape)
我想我对Numba对此有什么了解还很多。如果我想使用nopython模式使用Numba进行优化,是否需要将规范传递给jitclass装饰器?如何定义字典的规格?我还需要声明形状元组吗?我查看了在jitclass装饰器上找到的documentation以及dict numba文档,但不确定该怎么做。当我运行上面的代码时,出现以下错误:
TypeError: spec values should be Numba type instances, got <class 'dict'>
我是否需要以某种方式将dict元素包括在规范中?从文档中尚不清楚正确的语法是什么。
或者,有没有办法让Numba推断输入类型?
答案 0 :(得分:1)
spec
必须由 numba特定类型组成,而不是python类型!
因此,规范中的tuple
和dict
必须键入 numba类型(并且afaik仅允许同构字典)。
因此,您可以在here所示的jitted函数中指定params_default
字典,也可以显式键入数字字典as shown here。
在这种情况下,我将采用后一种方法:
import numpy as np
from numba import jitclass, float64
# Explicitly define the types of the key and value:
params_default = nb.typed.Dict.empty(
key_type=nb.typeof('par_1'),
value_type=nb.typeof(0.5)
)
# assign your default values
params_default['par_1'] = 1. # Same type required, thus setting to float
params_default['par_2'] = .5
spec = [
('p', nb.typeof(params_default)),
('shape', nb.typeof((100, 100))), # the shape of the array
('array', float64[:, :]), # an array field
]
@jitclass(spec)
class myObj:
def __init__(self, params=params_default, shape=(100, 100)):
self.p = params
self.shape = shape
self.array = self.p['par_2'] * np.ones(shape)
def inc_arr(self):
self.array += self.p['par_1'] * np.ones(shape)
正如已经指出的那样:字典是afaik的同质类型。因此,所有键/值都必须是同一类型。因此,将int
和float
存储在同一字典中将无法正常工作。