在Numba中设置结构化数组字段

时间:2018-04-18 19:27:52

标签: python numpy numba

我想在Numba编译的nopython函数中设置NumPy结构化标量的整个字段。下面代码中的desired_fn是我想要做的简单示例,而working_fn是我当前如何完成此任务的一个示例。

import numpy as np
import numba as nb


test_numpy_dtype = np.dtype([("blah", np.int64)])
test_numba_dtype = nb.from_dtype(test_numpy_dtype)


@nb.njit
def working_fn(thing):
    for j in range(len(thing)):
        thing[j]['blah'] += j

@nb.njit
def desired_fn(thing):
    thing['blah'] += np.arange(len(thing))


a = np.zeros(3,test_numpy_dtype)
print(a)
working_fn(a)
print(a)
desired_fn(a)

运行desired_fn(a)产生的错误是:

numba.errors.InternalError: unsupported array index type const('blah') in [const('blah')]
[1] During: typing of staticsetitem at /home/sam/PycharmProjects/ChessAI/playground.py (938)

这是 性能关键代码所需要的,并且将运行数十亿次,因此无需使用这些类型的循环似乎至关重要。

2 个答案:

答案 0 :(得分:1)

以下作品(numba 0.37):

@nb.njit
def desired_fn(thing):
    thing.blah[:] += np.arange(len(thing))
    # or
    # thing['blah'][:] += np.arange(len(thing))

如果您主要使用数据列而不是行,则可以考虑使用其他数据容器。 numpy结构化数组的布局类似于结构的向量而不是数组的结构。这意味着当您想要更新blah时,在遍历数组时,您将穿过非连续的内存空间。

此外,对于任何代码优化,使用timeit或其他一些时序线束(删除jit代码所需的时间)来查看实际性能是多么值得。您可能会发现numba显式循环,而更详细的实际上可能比矢量化代码更快。

答案 1 :(得分:1)

没有numba,访问字段值并不比访问二维数组的列慢:

In [1]: arr2 = np.zeros((10000), dtype='i,i')
In [2]: arr2.dtype
Out[2]: dtype([('f0', '<i4'), ('f1', '<i4')])

修改字段:

In [4]: %%timeit x = arr2.copy()
   ...: x['f0'] += 1
   ...: 
16.2 µs ± 13.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

如果我将字段分配给新变量,则类似的时间:

In [5]: %%timeit x = arr2.copy()['f0']
   ...: x += 1
   ...: 
15.2 µs ± 14.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

如果构造一个相同大小的1d数组,速度会快得多:

In [6]: %%timeit x = np.zeros(arr2.shape, int)
   ...: x += 1
   ...: 
8.01 µs ± 15.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

但访问二维数组的列的时间相似:

In [7]: %%timeit x = np.zeros((arr2.shape[0],2), int)
   ...: x[:,0] += 1
   ...: 
17.3 µs ± 23.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)