Numba:用数组中的值替换键的快速方法

时间:2019-12-30 23:09:45

标签: python arrays performance numpy numba

我想用具有重复元素的大尺寸keysvalues替换array。我正在尝试numbanumpy映射方法。两种方法的代码如下。

import numpy as np
from numba import njit, prange

array1 = np.arange(150*150*150, dtype=int)
array2 = np.arange(150*150*150, dtype=int)
array = np.concatenate((array1, array2))

keys = np.arange(50)
values = -1 * np.arange(50)

## Numba Approach
@njit(parallel=True)
def numba_replace(array, keys, values):

    for i in prange(len(keys)):
        for j in prange(len(array)):
            if array[j] == keys[i]:
               array[j] = values[i]


## numpy approach
def numpy_replace(array, keys, values):

    mapp = np.arange(array.size)
    mapp[keys] = values
    mapped = mapp[array]

    return mapped

## Performance 
%%timeit
numba_replace(array, keys, values)
# 117 ms ± 969 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%%timeit
numpy_replace(array, keys, values)
# 61.2 ms ± 159 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

尽管numpy_replacenumba_replace快2倍,但我不喜欢使用它,因为我的数组大小很大(3000 x 3000 x 3000),并且numpy方法会增加new array内存使用情况。有什么方法可以使numba_replace更快,或者在处理过程中没有创建新数组的任何方法?

2 个答案:

答案 0 :(得分:1)

改进Numba方法(降低复杂性)

由于您只想更改相对较小的值,因此可以使用集合来确定是否必须更改实际的数组元素。 另外,您可以使用search_sorted获取正确的键,值对。对于这个小例子,差异并不大,但是如果问题规模增大,差异将变得更大。

实施

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def numba_replace(array, keys, values):
    ind_sort=np.argsort(keys)
    keys_sorted=keys[ind_sort]
    values_sorted=values[ind_sort]
    s_keys=set(keys)

    for j in prange(array.shape[0]):
        if array[j] in s_keys:
            ind = np.searchsorted(keys_sorted,array[j])
            array[j]=values_sorted[ind]
    return array

时间

import numpy as np
from numba import njit, prange

array1 = np.arange(150*150*150, dtype=int)
array2 = np.arange(150*150*150, dtype=int)
array = np.concatenate((array1, array2))

#to get proper timings do nothing here
#changing the array in-place will obviously have 
#an influence on the timings, because there are no values to change in the second run
keys = np.arange(50)
values = np.arange(50)

%timeit numba_replace(array, keys, values)
# 20.1 ms ± 1.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit numpy_replace(array, keys, values)
# 51.3 ms ± 392 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

答案 1 :(得分:0)

我猜是这样

array[keys] = values

在numpy中完成工作,而无需创建任何新数组

编辑:仅检查该命令是否执行与您的numpy_replace函数相同的操作:

mapped = numpy_replace(array, keys, values)
array[keys] = values
print(all(mapped == array)) # --> True