Python和Numba:尽可能快地访问结构化numpy数组元素

时间:2018-09-19 15:33:32

标签: python arrays numpy numba

我有一个具有以下数据类型的大型结构化numpy数组:

> my_array.dtype
= dtype([('field1', '<i4', (32,)), ('field2', '<i4', (425,)), 
         ('field3', '<i4', (8021,))])

我的目标是尽可能快地访问任何给定的单个元素。如果按字段名称对数组进行切片,则Numba可以在非python模式下执行。

@numba.njit
def test_function(my_array_sliced):
    for i in range(len(my_array_sliced)):
        _ = my_array_sliced[i]
    return

> %timeit test_function(my_array['field1'])
  399 ns ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

这很好用,但是我意识到我可以通过使用'constant'字段名以不同的方式访问数组来达到更好的性能。

@numba.njit
def test_function2(my_array):
    for i in range(len(my_array)):
        _ = my_array[i]['field1']
    return

> %timeit test_function2(my_array)
  280 ns ± 5.88 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

399ns与280ns-差异很大!我真的不能说出为什么会有这样的区别。

如果我尝试通过在函数中添加参数来获得更好的性能,则会出现输入错误。

@numba.njit
def test_function3(my_array, v):
    for i in range(len(my_array)):
        _ = my_array[i][v]
    return


> %timeit test_function3(my_array, 0):
  TypingError: Failed at nopython (nopython frontend)
  Invalid usage of getitem with parameters (Record([('field1', '<i4', 
  (32,)), ('field2', '<i4', (425,)), ('field3', '<i4', (8021,))]), int64)

如果我将int参数替换为字符串参数(例如'field1'),也会发生类似的事情(即使如此,我也不应该首先使用字符串,因为Numba并不真正处理它们)。 / p>

因此,从性能角度考虑,目前最好的选择是为每个字段名称创建不同的功能。那显然是疯了。我应该如何修改代码以获得最佳性能?

谢谢!

0 个答案:

没有答案