numba njit在二维np.array索引上给出my和error

时间:2019-04-16 12:40:03

标签: python numba

我正在尝试使用包含我要的索引B(矩阵a的切片)的向量在njit函数中为二维矩阵D编制索引 这是一个最小的例子:

import numba as nb
import numpy as np

@nb.njit()
def test(N,P,B,D):
    for i in range(N):
        a = D[i,:]
        b =  B[i,a]
        P[:,i] =b

P = np.zeros((5,5))
B = np.random.random((5,5))*100
D = (np.random.random((5,5))*5).astype(np.int32)
print(D)
N = 5
print(P)
test(N,P,B,D)
print(P)

我在b = B[i,a]行出现了numba错误

File "dj.py", line 10:
def test(N,P,B,D):
    <source elided>
        a = D[i,:]
        b =  B[i,a]
        ^

This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.

我不明白我在这里做错了什么。 该代码无需使用@nb.njit()装饰器

即可工作

2 个答案:

答案 0 :(得分:1)

numba不支持与numpy相同的所有“花式索引”,在这种情况下,问题是选择具有a数组的数组元素。

对于您的特殊情况,由于您事先知道b的形状,因此可以这样解决:

import numba as nb
import numpy as np

@nb.njit
def test(N,P,B,D):
    b = np.empty(D.shape[1], dtype=B.dtype)

    for i in range(N):
        a = D[i,:]
        for j in range(a.shape[0]):
            b[j] = B[i, j]
        P[:, i] = b

答案 1 :(得分:1)

另一种解决方案是在调用 test 之前在 B 上应用交换轴并反转索引 (B[i,a] -> B[a,i])。我不知道这为什么有效,但这是实现:

import numba as nb
import numpy as np

@nb.njit()
def test(N,P,B,D):
    for i in range(N):
        a = D[i,:]
        b =  B[a,i]
        P[:, i] = b
    
P = np.zeros((5,5))
B = np.arange(25).reshape((5,5))
D = (np.random.random((5,5))*5).astype(np.int32)
N = 5
test(N,P,np.swapaxes(B, 0, 1), D)

顺便说一句,@chrisb给出的答案中,不是:b[j] = B[i, j]而是b[j] = B[i, a[j]]