我有一个具有以下数据类型的大型结构化numpy数组:
> my_array.dtype
= dtype([('field1', '<i4', (32,)), ('field2', '<i4', (425,)),
('field3', '<i4', (8021,))])
我的目标是尽可能快地访问任何给定的单个元素。如果按字段名称对数组进行切片,则Numba可以在非python模式下执行。
@numba.njit
def test_function(my_array_sliced):
for i in range(len(my_array_sliced)):
_ = my_array_sliced[i]
return
> %timeit test_function(my_array['field1'])
399 ns ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
这很好用,但是我意识到我可以通过使用'constant'字段名以不同的方式访问数组来达到更好的性能。
@numba.njit
def test_function2(my_array):
for i in range(len(my_array)):
_ = my_array[i]['field1']
return
> %timeit test_function2(my_array)
280 ns ± 5.88 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
399ns与280ns-差异很大!我真的不能说出为什么会有这样的区别。
如果我尝试通过在函数中添加参数来获得更好的性能,则会出现输入错误。
@numba.njit
def test_function3(my_array, v):
for i in range(len(my_array)):
_ = my_array[i][v]
return
> %timeit test_function3(my_array, 0):
TypingError: Failed at nopython (nopython frontend)
Invalid usage of getitem with parameters (Record([('field1', '<i4',
(32,)), ('field2', '<i4', (425,)), ('field3', '<i4', (8021,))]), int64)
如果我将int参数替换为字符串参数(例如'field1'),也会发生类似的事情(即使如此,我也不应该首先使用字符串,因为Numba并不真正处理它们)。 / p>
因此,从性能角度考虑,目前最好的选择是为每个字段名称创建不同的功能。那显然是疯了。我应该如何修改代码以获得最佳性能?
谢谢!