在numba的jitclass中索引多维numpy数组

时间:2019-01-13 11:31:41

标签: python multidimensional-array indexing jit numba

我正在尝试将一个较小的多维数组插入到numba jitclass内的较大数组中。小数组设置为由索引列表定义的大数组的特定位置。

以下MWE显示了没有numba的问题-一切都按预期进行

counter

函数import numpy as np class NumbaClass(object): def __init__(self, n, m): self.A = np.zeros((n, m)) # solution 1 using pure python def nonNumbaFunction1(self, idx, values): self.A[idx[:, None], idx] = values # solution 2 using pure python def nonNumbaFunction2(self, idx, values): self.A[np.ix_(idx, idx)] = values if __name__ == "__main__": n = 6 m = 8 obj = NumbaClass(n, m) print(f'A =\n{obj.A}') idx = np.array([0, 2, 5]) values = np.arange(len(idx)**2).reshape(len(idx), len(idx)) print(f'values =\n{values}') obj.nonNumbaFunction1(idx, values) print(f'A =\n{obj.A}') obj.nonNumbaFunction2(idx, values) print(f'A =\n{obj.A}') nonNumbaFunction1在numba类中均不起作用。所以我目前的解决方案看起来像这样,我认为这不是很好

nonNumbaFunction2

所以我的问题是:

  • 有人知道在numba中进行此索引的解决方案吗,还是有另一个矢量化解决方案?
  • 是否有针对import numpy as np from numba import jitclass from numba import int64, float64 from collections import OrderedDict specs = OrderedDict() specs['A'] = float64[:, :] @jitclass(specs) class NumbaClass(object): def __init__(self, n, m): self.A = np.zeros((n, m)) # solution for numba jitclass def numbaFunction(self, idx, values): for i in range(len(values)): idxi = idx[i] for j in range(len(values)): idxj = idx[j] self.A[idxi, idxj] = values[i, j] if __name__ == "__main__": n = 6 m = 8 obj = NumbaClass(n, m) print(f'A =\n{obj.A}') idx = np.array([0, 2, 5]) values = np.arange(len(idx)**2).reshape(len(idx), len(idx)) print(f'values =\n{values}') obj.numbaFunction(idx, values) print(f'A =\n{obj.A}') 的更快的解决方案?

知道插入的数组很小(4x4到10x10)可能很有用,但是此索引出现在嵌套循环中,因此它也必须快速安静!后来我也需要为三维对象建立类似的索引。

1 个答案:

答案 0 :(得分:0)

由于numba的索引支持受到限制,我认为您比自己编写for循环还能做得更好。要使其在各个维度上通用,可以使用generated_jit装饰器进行特殊化。像这样:

def set_2d(target, values, idx):
    for i in range(values.shape[0]):
        for j in range(values.shape[1]):
            target[idx[i], idx[j]] = values[i, j]

def set_3d(target, values, idx):
    for i in range(values.shape[0]):
        for j in range(values.shape[1]):
            for k in range(values.shape[2]):
                target[idx[i], idx[j], idx[k]] = values[i, j, l]

@numba.generated_jit
def set_nd(target, values, idx):
    if target.ndim == 2:
        return set_2d
    elif target.ndim == 3:
        return set_3d

然后,可以在您的jitclass中使用

specs = OrderedDict()
specs['A'] = float64[:, :]

@jitclass(specs)
class NumbaClass(object):
    def __init__(self, n, m):
        self.A = np.zeros((n, m))
    def numbaFunction(self, idx, values):
        set_nd(self.A, values, idx)