我正在学习如何使用Numba通过jit和vectorize加速功能。这段代码的jit版本没有任何问题,但是我在vectorize中遇到了索引错误。我怀疑这个question's答案是正确的想法,即存在类型错误,但是我不确定改变索引的方向。下面包括我一直在使用的功能,该功能输出斐波那契数直至序列的选定索引。索引出了什么问题?如何纠正我的代码以解决这个问题?
from numba import vectorize
import numpy as np
from timeit import timeit
@vectorize
def fib(n):
'''
Adjusted from:
https://lectures.quantecon.org/py/numba.html
https://en.wikipedia.org/wiki/Fibonacci_number
https://www.geeksforgeeks.org/program-for-nth-fibonacci-number/
'''
if n == 1:
return np.ones(1)
elif n > 1:
x = np.empty(n)
x[0] = 1
x[1] = 1
for i in range(2,n):
x[i] = x[i-1] + x[i-2]
return x
else:
print('WARNING: Check validity of input.')
print(timeit('fib(10)', globals={'fib':fib}))
这将导致以下错误输出。
Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/llvmlite/ir/instructions.py", line 619, in __init__
typ = typ.elements[i]
AttributeError: 'PointerType' object has no attribute 'elements'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/galen/Projects/myjekyllblog/test_code/quantecon_2.py", line 27, in <module>
print(timeit('fib(10)', globals={'fib':fib}))
File "/usr/lib/python3.6/timeit.py", line 233, in timeit
return Timer(stmt, setup, timer, globals).timeit(number)
File "/usr/lib/python3.6/timeit.py", line 178, in timeit
timing = self.inner(it, self.timer)
File "<timeit-src>", line 6, in inner
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/dufunc.py", line 166, in _compile_for_args
return self._compile_for_argtys(tuple(argtys))
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/dufunc.py", line 188, in _compile_for_argtys
cres, actual_sig)
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/ufuncbuilder.py", line 157, in _build_element_wise_ufunc_wrapper
cres.objectmode, cres)
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/wrappers.py", line 220, in build_ufunc_wrapper
env=envptr)
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/wrappers.py", line 130, in build_fast_loop_body
env=env)
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/wrappers.py", line 23, in _build_ufunc_loop_body
store(retval)
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/wrappers.py", line 126, in store
out.store_aligned(retval, ind)
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/wrappers.py", line 276, in store_aligned
self.context.pack_value(self.builder, self.fe_type, value, ptr)
File "/usr/local/lib/python3.6/dist-packages/numba/targets/base.py", line 482, in pack_value
dataval = self.data_model_manager[ty].as_data(builder, value)
File "/usr/local/lib/python3.6/dist-packages/numba/datamodel/models.py", line 558, in as_data
elems = self._as("as_data", builder, value)
File "/usr/local/lib/python3.6/dist-packages/numba/datamodel/models.py", line 530, in _as
self.get(builder, value, i)))
File "/usr/local/lib/python3.6/dist-packages/numba/datamodel/models.py", line 558, in as_data
elems = self._as("as_data", builder, value)
File "/usr/local/lib/python3.6/dist-packages/numba/datamodel/models.py", line 530, in _as
self.get(builder, value, i)))
File "/usr/local/lib/python3.6/dist-packages/numba/datamodel/models.py", line 624, in get
name="extracted." + self._fields[pos])
File "/usr/local/lib/python3.6/dist-packages/llvmlite/ir/builder.py", line 911, in extract_value
instr = instructions.ExtractValue(self.block, agg, idx, name=name)
File "/usr/local/lib/python3.6/dist-packages/llvmlite/ir/instructions.py", line 622, in __init__
% (list(indices), agg.type))
TypeError: Can't index at [0] in i8*
答案 0 :(得分:1)
错误是因为您试图vectorize
可以说基本上无法向量化的功能。我认为您对@jit
和@vectorize
的工作方式感到困惑。为了加快您的功能,可以使用@jit
,而@vectorize
用于创建numpy通用函数。参见official documentation here:
使用vectorize(),将函数编写为对输入进行操作 标量,而不是数组。 Numba将生成周围的循环 (或内核)允许对实际输入进行有效的迭代。
因此,基本上不可能创建具有与fibonacci函数相同功能的numpy通用函数。如果您有兴趣,这是official documentation on universal functions的链接。
因此,要使用@vectorize
,您需要创建一个实际上可以用作numpy通用函数的函数。为了加快代码的速度,您只需要使用@jit
。