Numba和二维numpy数组列表

时间:2019-12-04 14:25:35

标签: python list numpy numba

我上了一堂课

from numba import jitclass, int32, float32, types
from numba.typed import List

_spec = [
    ('Y_rf', types.List(float32[:, :])),
    ...
]

@jitclass(_spec)
class DensityRatioEstimation:
    def __init__(self, sigma):
        self.sigma = sigma
        self.Y_rf = [np.array([[0.]], dtype=float32)]

但是我无法使其工作。它总是因不同的错误而崩溃。现在的错误是:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Internal error at <numba.typeinfer.CallConstraint object at 0x00000277CBBBF550>.
Failed in nopython mode pipeline (step: nopython mode backend)


File "src\models\DDRE.py", line 26:
    def __init__(self, sigma):
        <source elided>
        self.sigma = sigma
        self.Y_rf = [np.array([[0.]], dtype=float32)]
        ^

[1] During: lowering "(self).Y_rf = $0.11" at D:\anomaly-detection\src\models\DDRE.py (26)
[2] During: resolving callee type: jitclass.DensityRatioEstimation#277c8a6cdd8<sigma:float32,Y_rf:list(array(float32, 2d, A)),Y_te:list(array(float32, 2d, A)),k:int32,alphas:array(float32, 1d, A),b:array(float32, 1d, A)>
[3] During: typing of call at <string> (3)

Enable logging at debug level for details.

File "<string>", line 3:
<source missing, REPL/exec in use?>

我还尝试使用List.empty_list(float32[:, :])中的numba.types.List代替[np.array([[0.]], dtype=float32)]。但这也不起作用。该如何解决?

1 个答案:

答案 0 :(得分:1)

您的代码段存在一个问题,您正在尝试使用Numba dtype创建一个Numpy数组。

np.array([[1, 2], [3, 4]], dtype=np.float32) # OK
np.array([[1, 2], [3, 4]], dtype=nb.float32) # Not OK

但是,主要问题是您需要使用numba.types.npytypes.Array指定列表的类型。这与使用float32([:,:])指定数组的函数签名不同。

import numba as nb
import numpy as np

_spec = [
    ('Y_rf', nb.types.List(nb.types.Array(nb.types.float32, 2, 'C'))),
    ('sigma', nb.types.int32)
]

@jitclass(_spec)
class DensityRatioEstimation:
    def __init__(self, sigma):
        self.sigma = sigma
        self.Y_rf = [np.array([[1, 2], [3, 4]], dtype=np.float32)]


dre = DensityRatioEstimation(1)
dre.Y_rf

输出

[array([[1., 2.],
        [3., 4.]], dtype=float32)]