我正在尝试将一个较小的多维数组插入到numba jitclass内的较大数组中。小数组设置为由索引列表定义的大数组的特定位置。
以下MWE显示了没有numba的问题-一切都按预期进行
counter
函数import numpy as np
class NumbaClass(object):
def __init__(self, n, m):
self.A = np.zeros((n, m))
# solution 1 using pure python
def nonNumbaFunction1(self, idx, values):
self.A[idx[:, None], idx] = values
# solution 2 using pure python
def nonNumbaFunction2(self, idx, values):
self.A[np.ix_(idx, idx)] = values
if __name__ == "__main__":
n = 6
m = 8
obj = NumbaClass(n, m)
print(f'A =\n{obj.A}')
idx = np.array([0, 2, 5])
values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
print(f'values =\n{values}')
obj.nonNumbaFunction1(idx, values)
print(f'A =\n{obj.A}')
obj.nonNumbaFunction2(idx, values)
print(f'A =\n{obj.A}')
和nonNumbaFunction1
在numba类中均不起作用。所以我目前的解决方案看起来像这样,我认为这不是很好
nonNumbaFunction2
所以我的问题是:
import numpy as np
from numba import jitclass
from numba import int64, float64
from collections import OrderedDict
specs = OrderedDict()
specs['A'] = float64[:, :]
@jitclass(specs)
class NumbaClass(object):
def __init__(self, n, m):
self.A = np.zeros((n, m))
# solution for numba jitclass
def numbaFunction(self, idx, values):
for i in range(len(values)):
idxi = idx[i]
for j in range(len(values)):
idxj = idx[j]
self.A[idxi, idxj] = values[i, j]
if __name__ == "__main__":
n = 6
m = 8
obj = NumbaClass(n, m)
print(f'A =\n{obj.A}')
idx = np.array([0, 2, 5])
values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
print(f'values =\n{values}')
obj.numbaFunction(idx, values)
print(f'A =\n{obj.A}')
的更快的解决方案?知道插入的数组很小(4x4到10x10)可能很有用,但是此索引出现在嵌套循环中,因此它也必须快速安静!后来我也需要为三维对象建立类似的索引。
答案 0 :(得分:0)
由于numba的索引支持受到限制,我认为您比自己编写for循环还能做得更好。要使其在各个维度上通用,可以使用generated_jit
装饰器进行特殊化。像这样:
def set_2d(target, values, idx):
for i in range(values.shape[0]):
for j in range(values.shape[1]):
target[idx[i], idx[j]] = values[i, j]
def set_3d(target, values, idx):
for i in range(values.shape[0]):
for j in range(values.shape[1]):
for k in range(values.shape[2]):
target[idx[i], idx[j], idx[k]] = values[i, j, l]
@numba.generated_jit
def set_nd(target, values, idx):
if target.ndim == 2:
return set_2d
elif target.ndim == 3:
return set_3d
然后,可以在您的jitclass中使用
specs = OrderedDict()
specs['A'] = float64[:, :]
@jitclass(specs)
class NumbaClass(object):
def __init__(self, n, m):
self.A = np.zeros((n, m))
def numbaFunction(self, idx, values):
set_nd(self.A, values, idx)