如何使用Numba加快python的字典速度

时间:2019-12-26 16:09:33

标签: python arrays python-3.x dictionary numba

我需要在布尔值数组中存储一些单元格。最初,我使用numpy,但是当数组开始占用大量内存时,我就想到了将非零元素存储在以元组为键的字典中(因为它是可哈希的类型)。例如: {(0, 0, 0): True, (1, 2, 3): True}(这是“ 3D数组”中的两个单元格,它们的索引分别为0,0,0和1,2,3,但是维数事先未知,并且在运行算法时已定义)。 这很有帮助,因为非零单元格仅填充了数组的一小部分。

要编写此字典并从中获取值,我需要使用循环:

def fill_cells(indices, area_dict):
    for i in indices:
        area_dict[tuple(i)] = 1

def get_cells(indices, area_dict):
    n = len(indices)
    out = np.zeros(n, dtype=np.bool)
    for i in range(n):
        out[i] = tuple(indices[i]) in area_dict.keys()
    return out

现在我需要用Numba加快速度。 Numba不支持本机Python的dict(),因此我使用了numba.typed.Dict。 问题在于Numba想要在定义功能的阶段知道键的大小,因此我什至无法创建字典(键的长度是预先未知的,并且在调用函数时已定义):

@njit
def make_dict(n):
    out = {(0,)*n:True}
    return out

Numba无法正确推断字典键的类型并返回错误:

Compilation is falling back to object mode WITH looplifting enabled because Function "make_dict" failed type inference due to: Invalid use of Function(<built-in function mul>) with argument(s) of type(s): (tuple(int64 x 1), int64)

如果我在函数中将n更改为具体数字,它将起作用。我用这个技巧解决了它:

n = 10
s = '@njit\ndef make_dict():\n\tout = {(0,)*%s:True}\n\treturn out' % n
exec(s)

但是我认为这是错误的低效方式。而且我需要将我的fill_cells和get_cells函数与@njit装饰器一起使用,但是Numba返回相同的错误,因为我试图在此函数中从numpy数组创建元组。

我了解Numba的基本局限性(以及一般的编译方式),但是也许有某种方法可以加快功能,或者也许您对我的存储单元问题有另一种解决方案?

1 个答案:

答案 0 :(得分:0)

最终解决方案:

主要问题是Numba在定义创建元组的函数时需要知道元组的长度。诀窍是每次都重新定义功能。我需要使用定义功能的代码生成字符串,然后使用 exec()

运行它
n = 10
s = '@njit\ndef arr_to_tuple(a):\n\treturn (' + ''.join('a[%i],' % i for i in range(n)) + ')'
exec(s)

此后,我可以调用 arr_to_tuple(a)创建可以在另一个@njit装饰的函数中使用的元组。

例如,创建用于解决问题的元组键的空字典:

@njit
def make_empty_dict():
    tpl = arr_to_tuple(np.array([0]*5))
    out = {tpl:True}
    del out[tpl]
    return out

我在字典中写一个元素,因为它是Numba推断类型的方法之一。

此外,我需要使用问题中所述的 fill_cells get_cells 函数。这就是我用Numba重写它们的方式:

文字元素。刚刚将tuple()更改为arr_to_tuple():

@njit
def fill_cells_nb(indices, area_dict):
    for i in range(len(indices)):
        area_dict[arr_to_tuple(indices[i])] = True

从字典中获取元素需要一些令人毛骨悚然的代码:

@njit
def get_cells_nb(indices, area_dict):
    n = len(indices)
    out = np.zeros(n, dtype=np.bool_)
    for i in range(n):
        new_len = len(area_dict)
        tpl = arr_to_tuple(indices[i])
        area_dict[tpl] = True
        old_len = len(area_dict)
        if new_len == old_len:
            out[i] = True
        else:
            del area_dict[tpl]
    return out

我的Numba版本(0.46)不支持.contains(in)运算符和try-except构造。如果您拥有支持它的版本,则可以为其编写更多的“常规”解决方案。

因此,当我想检查字典中是否存在带有某些索引的元素时,我会记住它的长度,然后在字典中写一些带有所提及索引的内容。如果长度改变了,我得出结论该元素不存在。否则,该元素存在。看起来很慢,但事实并非如此。

速度测试:

解决方案出奇地快。与原生Python优化代码相比,我用%timeit进行了测试:

  1. arr_to_tuple()比常规 tuple()函数快5倍
  2. 具有numba的get_cells Python原生get_cells
  3. 相比,一个元素的速度快3倍,对于大型元素的速度快40倍。 与 Python编写的fill_cells 相比,
  4. 具有numba的fill_cells 一个元素快4倍,对于大型元素快40倍。