在尝试通过使用numpy和优化没有任何进一步改进的情况下,我试图使用numba加快算法。
我有一个函数可以在一个大型的2倍嵌套循环中进行一些计算:
import random
from numba import njit
@njit()
def decide_if_vaild():
return bool(random.getrandbits(1))
@njit()
def decide_what_bin(bins):
return random.randint(0, bins-1)
@njit()
def foo(bins, loops):
results = [[] for _ in range(bins)]
for i in range(loops):
for j in range(i+1, loops):
happy = decide_if_vaild()
bin = decide_what_bin(bins)
if happy:
results[bin].append( (i,j) )
# or
# results[bin].append( [i,j] )
return results
if __name__ == '__main__':
x = foo(3,100)
如果我在上面运行此最小示例,我(如预期的那样)会遇到输入错误:
File "C:\Users\xxx\AppData\Local\Programs\Python\Python36\lib\site-packages\numba\typeinfer.py", line 104, in getone
assert self.type is not None
numba.errors.InternalError:
[1] During: typing of call at C:/Users/xxx/minimal_example.py (21)
--%<-----------------------------------------------------------------
File "minimal_example.py", line 21
问题是:“ results [bin] .append((i,j))”在这里尝试向列表中添加一个元组(也不适用于列表)。
箱的数量是预先知道的,但是有多少个元素(2元组或列表或np.array)取决于decision_if_vaild评估为True的频率,而我不知道它的频率和计算非常昂贵,我不知道其他解决方法。
有什么好主意,我如何在jitted函数中生成结果并将其返回,或者传递一个可以填充此函数的全局容器?
这可能归结为:
numba.errors.LoweringError: Failed at nopython (nopython mode backend)
list(list(list(int64))): unsupported nested memory-managed object
从https://github.com/numba/numba/issues/2560开始,在numba 0.39.0中解决了list(list(int64))(https://github.com/numba/numba/pull/2840)的类似问题
答案 0 :(得分:0)
我现在已经实施了以下解决方法,即使它不能完全回答问题,对于其他正在努力解决此问题的人来说,它可能也是一种合适的方法:
@njit()
def foo(bins, loops):
results = []
mapping = []
for i in range(loops):
for j in range(loops+1, size):
happy = decide_if_vaild()
bin = decide_what_bin(bins)
if happy:
results.append( (i,j) )
mapping.append( bin )
return results, mapping
这将返回一个元组列表(从numba 0.39.0开始受支持)和一个映射列表,其中mapping [i]包含结果[i]的bin。现在,jit编译器可以顺利运行,并且我可以在jit之外解压缩结果。