我应该如何使用数组(或元组)作为Numba类型字典的键和值?

时间:2019-08-15 18:20:13

标签: numba

我有以下代码尝试将键值对存储到numba字典中。 Numba的官方页面说,新类型的字典支持将array作为键,但是我无法使其正常工作。 错误消息指出密钥不能为哈希。 知道如何使它正常工作吗?

In [7]: from numba.typed import Dict 
   ...: from numba import types 
   ...: import numpy as np        

In [15]: dd = Dict.empty(key_type=types.int32[::1], value_type=types.int32[::1],)                                                                                                                                  

In [16]: key = np.asarray([1,2,3], dtype=np.int32)                                                                                                                                                                 

In [17]: dd[key] = key   

错误消息:

TypingError:在nopython模式管道中失败(步骤:nopython前端) 数组类型为(int32,1d,C)的未知属性'哈希'

编辑: 我可能错过了一些东西。我可以在解释器中使用type.UniTuple(没有@jit装饰器)。但是,当我将以下函数放入脚本a.py并使用命令“ python a.py”运行时,出现了UniTuple not found错误。

@jit(nopython=True)
def go_fast2(date, starttime, id, tt, result): # Function is compiled and runs in machine code
    prev_record = Dict.empty(key_type=types.UniTuple(types.int64, 2),  value_type=types.UniTuple(types.int64, 3),)
    for i in range(1, length):
        key = np.asarray([date[i], id[i]], dtype=np.int64)
        thistt = tt[i]
        thistime = starttime[i]
        if key in prev_record:
            prev_time = prev_record[key][0]
            prev_tt = prev_record[key][1]
            prev_res = prev_record[key][2]
            if thistt == prev_tt and thistime - prev_time <= 30 * 1000 * 1000: # with in a 10 seconds window
                result[i] = prev_res + 1
            else:
                result[i] = 0
            prev_record[key] = np.asarray((thistime, thistt, result[i]), dtype=np.int64)
        else:
            result[i] = 0
            prev_record[key] = np.asarray((thistime, thistt, result[i]), dtype=np.int64)
    return 

1 个答案:

答案 0 :(得分:1)

当前文档说:

  

可接受的键/值类型包括但不限于:unicode   字符串,数组,标量,元组。

这种措辞的确使您看起来可能可以使用数组作为键类型,但这是不正确的,因为数组是不可哈希的,因为它是可变的。它也不适用于标准的python dict。您可以将数组转换为元组,这将起作用:

dd = Dict.empty(
    key_type=types.UniTuple(types.int64, 3), 
    value_type=types.int64[::1],)
key = np.asarray([1,2,3], dtype=np.int64)
dd[tuple(key)] = key

请注意,您以前使用的int32 dtype在64位计算机上不起作用,因为在数组上调用tuple()时int32s的元组会自动转换为int64。

另一个问题是元组的大小是固定的,因此您不能使用任意大小的数组作为键。