我正在njit中编写一个函数来加速非常缓慢的油藏操作优化代码。该功能根据油藏水平和浇口可用性返回溢出释放的最大值。我传入的参数大小指定了要计算的流的数量(在某些调用中它是一个,在某些调用中是一些)。我也传入一个numpy.zeros数组,然后我可以用函数输出填充。该函数的简化版本如下:
import numpy as np
from numba import njit
@njit(cache=True)
def fncMaxFlow(elev, flag, size, MaxQ):
if (flag == 1): # SPOG2 running
if size==0:
if (elev>367.28):
return 861.1
else: return 0
else:
for i in range(size):
if((elev[i]>367.28) & (elev[i]<385)):
MaxQ[i]=861.1
return MaxQ
else:
if size==0: return 0
else: return MaxQ
fncMaxFlow(np.random.randint(368, 380, 3), 1, 3, np.zeros(3))
我得到的错误:
Can't unify return type from the following types: array(float64, 1d, C), float64, int32
这是什么原因?是否有任何解决方法或我缺少的一些步骤,所以我可以使用numba来加快速度?这个函数和其他类似函数被称为数百万次,因此它们是计算效率的主要因素。任何建议都会有所帮助 - 我对python很新。
答案 0 :(得分:3)
numba函数中的变量必须具有一致的类型,包括返回变量。在您的代码中,您可以返回MaxQ
(数组),861.1(浮点数)或0(一个int)。
您需要重构此代码,以便无论代码路径如何,它始终返回一致的类型。
另请注意,在您将numpy数组与标量(elev > 367.28
)进行比较的几个地方,您得到的是一个布尔值数组,这将导致您遇到问题。由于这个原因,你的示例函数不能作为纯python函数运行(删除numba decorator)。