由于数据类型,编译njit nopython版本的函数失败

时间:2017-07-28 18:46:15

标签: python performance jit numba

我正在njit中编写一个函数来加速非常缓慢的油藏操作优化代码。该功能根据油藏水平和浇口可用性返回溢出释放的最大值。我传入的参数大小指定了要计算的流的数量(在某些调用中它是一个,在某些调用中是一些)。我也传入一个numpy.zeros数组,然后我可以用函数输出填充。该函数的简化版本如下:

import numpy as np
from numba import njit

@njit(cache=True)
def fncMaxFlow(elev, flag, size, MaxQ):
    if (flag == 1): # SPOG2 running
        if size==0:
            if (elev>367.28):
                return 861.1 
            else: return 0
        else:
            for i in range(size):
                if((elev[i]>367.28) & (elev[i]<385)):
                    MaxQ[i]=861.1
            return MaxQ
    else:
        if size==0: return 0
        else: return MaxQ

fncMaxFlow(np.random.randint(368, 380, 3), 1, 3, np.zeros(3))

我得到的错误:

Can't unify return type from the following types: array(float64, 1d, C), float64, int32

这是什么原因?是否有任何解决方法或我缺少的一些步骤,所以我可以使用numba来加快速度?这个函数和其他类似函数被称为数百万次,因此它们是计算效率的主要因素。任何建议都会有所帮助 - 我对python很新。

1 个答案:

答案 0 :(得分:3)

numba函数中的变量必须具有一致的类型,包括返回变量。在您的代码中,您可以返回MaxQ(数组),861.1(浮点数)或0(一个int)。

您需要重构此代码,以便无论代码路径如何,它始终返回一致的类型。

另请注意,在您将numpy数组与标量(elev > 367.28)进行比较的几个地方,您得到的是一个布尔值数组,这将导致您遇到问题。由于这个原因,你的示例函数不能作为纯python函数运行(删除numba decorator)。