TypingError in Numba 0.37

Posted: 2018-04-15 13:06:39

Tags: python python-2.7 numpy jit numba

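I am trying to optimize Python code that computes the following formula: [formula image not available]

where phi is a 2D array and phi_i is a 1D array. I wrote code for it and tried the jit decorator, but it gives me a TypingError. Here is the code I used: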

import numpy as np
from numba import jit
@jit(nopython=True)
def calcAlpha(phi,fix_phis):
    phi_sq = phi**2
    fix_phis_sq = fix_phis**2
    F = []
    for l,phi_l_sq in enumerate(fix_phis_sq):
        F.append(2.0*phi_sq/(phi_sq-phi_l_sq))
#        print F[l]
        for j,phi_j_sq in enumerate(fix_phis_sq):
            if j != l:
                F[l]*=(phi_sq - phi_j_sq)/(phi_l_sq + phi_j_sq)
                F[l]*=(phi_l_sq + phi_j_sq)/(phi_sq + phi_j_sq)
    return np.array(F)  # nopython mode cannot turn a list of 2D arrays into an array: this raises the TypingError
fix_sigmas=np.linspace(0.1,1,8)
sigma = np.random.random((252,252))
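The formula image did not survive. Reading it back from the loop body of calcAlpha gives the formula as coded, which may not match the original image. Writing φ for an element of phi and φ_l, φ_j for entries of fix_phis (the 1D phi_i array; this notation is an assumption), each slice F[l] is computed element-wise over phi as:

```latex
F_l = \frac{2\phi^2}{\phi^2 - \phi_l^2}
      \prod_{j \neq l} \frac{\phi^2 - \phi_j^2}{\phi_l^2 + \phi_j^2}
      \cdot \frac{\phi_l^2 + \phi_j^2}{\phi^2 + \phi_j^2}
    = \frac{2\phi^2}{\phi^2 - \phi_l^2}
      \prod_{j \neq l} \frac{\phi^2 - \phi_j^2}{\phi^2 + \phi_j^2}
```

As written, the two phi_l^2 + phi_j^2 factors cancel, so each pair of multiplications inside the j loop reduces to a single ratio.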

Trying to run the function produces the following message:

In [7]: fout=calcAlpha(sigma,fix_sigmas)
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
<ipython-input-7-88b11ed9cd73> in <module>()
----> 1 fout=calcAlpha(sigma,fix_sigmas)

../anaconda2/lib/python2.7/site-packages/numba/dispatcher.pyc in _compile_for_args(self, *args, **kws)
    328                                 for i, err in failed_args))
    329                 e.patch_message(msg)
--> 330             raise e
    331 
    332     def inspect_llvm(self, signature=None):

TypingError: Caused By:
Traceback (most recent call last):
  File "../anaconda2/lib/python2.7/site-packages/numba/compiler.py", line 240, in run
    stage()
  File "../anaconda2/lib/python2.7/site-packages/numba/compiler.py", line 454, in stage_nopython_frontend
    self.locals)
  File "../anaconda2/lib/python2.7/site-packages/numba/compiler.py", line 881, in type_inference_stage
    infer.propagate()
  File "../anaconda2/lib/python2.7/site-packages/numba/typeinfer.py", line 846, in propagate
    raise errors[0]
TypingError: Invalid usage of Function(<built-in function array>) with parameters (list(array(float64, 2d, C)))
 * parameterized
In definition 0:
    TypingError: array(float64, 2d, C) not allowed in a homogenous sequence
...

Failed at nopython (nopython frontend)
Invalid usage of Function(<built-in function array>) with parameters (list(array(float64, 2d, C)))
 * parameterized
In definition 0:
    TypingError: array(float64, 2d, C) not allowed in a homogenous sequence

1 Answer

Answer 0 (score: 1)

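Appending arrays to a list and then converting that list with np.array() is not supported (that np.array call is what the traceback complains about), and even if it becomes supported in the future, it is better avoided. It is also recommended to write out all loops explicitly to get the best performance out of Numba.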
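A minimal sketch of the first point, keeping the question's whole-array loop body but storing the results in one preallocated 3D array instead of a list (the name calc_alpha_prealloc is hypothetical). It assumes Numba's nopython mode accepts the array expressions and the in-place F[l] *= ... updates, as recent versions do; the try/except stand-in is only there so the sketch also runs as plain NumPy where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Stand-in only so the sketch also runs as plain NumPy without Numba
    def njit(func):
        return func


@njit
def calc_alpha_prealloc(phi, fix_phis):
    """Same math as calcAlpha in the question, but writes into a
    preallocated 3D array instead of appending 2D arrays to a list."""
    phi_sq = phi * phi
    fix_phis_sq = fix_phis * fix_phis
    n = fix_phis_sq.shape[0]
    # One slice F[l] per entry of fix_phis, each with phi's 2D shape
    F = np.empty((n, phi.shape[0], phi.shape[1]), dtype=phi.dtype)
    for l in range(n):
        phi_l_sq = fix_phis_sq[l]
        F[l] = 2.0 * phi_sq / (phi_sq - phi_l_sq)
        for j in range(n):
            if j != l:
                phi_j_sq = fix_phis_sq[j]
                F[l] *= (phi_sq - phi_j_sq) / (phi_l_sq + phi_j_sq)
                F[l] *= (phi_l_sq + phi_j_sq) / (phi_sq + phi_j_sq)
    return F
```

Called with the question's inputs, calc_alpha_prealloc(sigma, fix_sigmas) returns a single array of shape (8, 252, 252), so no list-to-array conversion is needed at the end.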

Example

import numpy as np
import numba as nb

@nb.njit(fastmath=True, parallel=True)
def calcAlpha(phi, fix_phis):
    phi_sq = phi*phi
    fix_phis_sq = fix_phis*fix_phis  # 1D
    # Preallocate one 3D result array instead of appending 2D arrays to a list
    F = np.zeros((fix_phis_sq.shape[0], phi.shape[0], phi.shape[1]), dtype=phi.dtype)

    # Parallel over l: each iteration writes only its own slice F[l]
    for l in nb.prange(fix_phis_sq.shape[0]):
        for x in range(phi_sq.shape[0]):
            for y in range(phi_sq.shape[1]):
                F[l,x,y] = 2.0*phi_sq[x,y]/(phi_sq[x,y]-fix_phis_sq[l])

        # If fix_phis stays small but phi grows, parallelize over phi (the x loop) instead of over l
        for j in range(fix_phis_sq.shape[0]):
            if j != l:
                for x in range(phi_sq.shape[0]):
                    for y in range(phi_sq.shape[1]):
                        F[l,x,y] *= (phi_sq[x,y] - fix_phis_sq[j])/(fix_phis_sq[l] + fix_phis_sq[j])
                        F[l,x,y] *= (fix_phis_sq[l] + fix_phis_sq[j])/(phi_sq[x,y] + fix_phis_sq[j])
    return F

Most of the speedup (7x on my quad-core i7) comes from the parallelization.
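The comment in the answer's code suggests parallelizing over phi instead of over l when fix_phis stays small. Below is a minimal sketch of that variant, assuming Numba's njit(parallel=True) and prange as used in the answer. The name calc_alpha_rows is hypothetical, and the try/except stand-ins (which simply run serially) only exist so the sketch also runs without Numba. Each prange iteration writes a different row x of every F[l], so no two iterations touch the same element.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Serial stand-ins so the sketch also runs without Numba
    prange = range

    def njit(*args, **kwargs):
        # Support both @njit and @njit(...) usage
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda func: func


@njit(parallel=True)
def calc_alpha_rows(phi, fix_phis):
    phi_sq = phi * phi
    s = fix_phis * fix_phis
    n = s.shape[0]
    F = np.empty((n, phi.shape[0], phi.shape[1]), dtype=phi.dtype)
    # Parallelize over the rows of phi; iteration x owns row x of every F[l]
    for x in prange(phi.shape[0]):
        for y in range(phi.shape[1]):
            p = phi_sq[x, y]
            for l in range(n):
                # Accumulate in a local scalar, then store once
                val = 2.0 * p / (p - s[l])
                for j in range(n):
                    if j != l:
                        val *= (p - s[j]) / (s[l] + s[j])
                        val *= (s[l] + s[j]) / (p + s[j])
                F[l, x, y] = val
    return F
```

Whether this beats parallelizing over l depends on the array sizes and the number of cores, so timing both versions on the real data is the reliable way to choose.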