我正在尝试定义@njit
函数,该函数在几个点需要计算最小值。如下所示:
min(arg1, arg2,...)
然而,当我去计算迭代的最小值时,无论如何定义迭代,我似乎都有问题。我在下面用一个小函数重现了这个:
itrbl = [5.0, 0.4, 4.5, 3.5, 6.4]
@njit
def funct(itrbl):
return min(itrbl)
funct(itrbl)
并收到以下错误:
TypingError: Invalid usage of Function(<built-in function min>) with parameters (reflected list(float64))
* parameterized
当我将itrbl
结构化为数组时似乎并不喜欢......我如何构建itrbl
以使这个简单的函数有效?
答案 0 :(得分:0)
在njit
ted函数中,这些函数调用被这些函数的numba版本取代。因此,当您调用min
时,它不是Python的min
函数。
另外,numba通常针对与numpy.array
而非list
一起使用进行了优化。当您将list
传递给numba njit
ted函数(从Python)或从list
ted函数(返回给Python)返回njit
时,它必须转换(将整个列表复制到反映列表中,或者返回(Why is passing a list (of length n) to a numba nopython function an O(n) operation时将其复制)。因此,您应避免将Python列表传递给numba函数或将其返回。但是,在numba函数中创建列表并将其传递给其他numba函数不会受到这种转换开销的困扰。
但是,除了有关numba函数中的list
和min
的实际问题之外,还有一点。看起来min
的numba版本不支持列表/数组(但它确实支持元组)。
您可以轻松创建一个支持列表的numba函数:
from numba import njit
@njit
def my_min(lst):
"Don't pass empty arrays/lists to this function!"
min_ = lst[0]
for item in lst:
if item < min_:
min_ = item
return min_
似乎可行:
from numba import njit
@njit
def funct(itrbl):
return my_min(itrbl)
>>> funct([5.0, 0.4, 4.5, 3.5, 6.4])
0.4
但是,更好的方法是在数字函数中使用array
。原因是,根据我在numba上的经验,您经常遇到numba不支持列表的情况(例如使用min
)或在列表上调用该方法的速度相对较慢。另外,将列表传递给numba函数或从numba函数返回它们的开销可能是巨大的(而且是意外的)。