Numpy数组作为查找表

时间:2014-03-25 06:59:16

标签: python numpy

相关但不同,恕我直言:

(1)numpy: most efficient frequency counts for unique values in an array

(2)Using Numpy arrays as lookup tables

设定:

import numpy as np
from scipy.stats import itemfreq

x = np.array([1,  1,  1,  2,  25000, 2,  2,  5,  1,  1])
fq = itemfreq(x)
fq.astype(int)
array([[    1,     5],
       [    2,     3],
       [    5,     1],
       [25000,     1]])

现在,我想使用fq作为查找表,并执行此操作:

res = magic_lookup_function(fq, x)
res
    array([5, 5, 5, 3, 1, 3, 3, 1, 5, 5])

如(1)和(2)中所建议的,我可以将fq转换为python字典,然后从那里查找,然后返回到np.array。但有没有更干净/更快/纯粹的numpy方式呢?

更新:另外,正如(2)中所建议的,我可以使用bincount,但我担心如果我的索引很大,那可能效率很低,例如: 〜25万。

谢谢!

更新的解决方案

正如@Jaime指出的那样(下图),np.unique最多在O(n log n)时间内对数组进行排序。所以我想知道,itemfreq引擎盖下会发生什么?结果itemfreq对数组进行排序,我假设它也是O(n log n):

In [875]: itemfreq??

def itemfreq(a):
... ... ...
    scores = _support.unique(a)
    scores = np.sort(scores)

这是一个时间示例

In [895]: import timeit

In [962]: timeit.timeit('fq = itemfreq(x)', setup='import numpy; from scipy.stats import itemfreq; x = numpy.array([ 1,  1,  1,  2, 250000,  2,  2,  5,  1,  1])', number=1000)
Out[962]: 0.3219749927520752

但似乎没有必要对数组进行排序。如果我们在纯python中执行它会发生什么。

In [963]: def test(arr):
   .....:     fd = {}
   .....:     for i in arr:
   .....:         fd[i] = fd.get(i,0) + 1
   .....:     return numpy.array([fd[j] for j in arr])

In [967]: timeit.timeit('test(x)', setup='import numpy; from __main__ import test; x = numpy.array([ 1,  1,  1,  2, 250000,  2,  2,  5,  1,  1])', number=1000)
Out[967]: 0.028257131576538086
哇,快10倍!

(至少,在这种情况下,数组不是太长,但可能包含大值。)

而且,仅供参考,正如我所怀疑的那样,使用np.bincount执行此操作对于较大的值来说效率很低:

In [970]: def test2(arr):
    bc = np.bincount(arr)
    return bc[arr]

In [971]: timeit.timeit('test2(x)', setup='import numpy; from __main__ import test2; x = numpy.array([ 1,  1,  1,  2, 250000,  2,  2,  5,  1,  1])', number=1000)
Out[971]: 0.0975029468536377

2 个答案:

答案 0 :(得分:3)

您可以使用numpy.searchsorted

def get_index(arr, val):                                                                
    index = np.searchsorted(arr, val)                                                            
    if arr[index] == val:                                                                        
        return index                                                                             

In [20]: arr = fq[:,:1].ravel()                                                                  

In [21]: arr
Out[21]: array([  1.,   2.,   5.,  25.])

In [22]: get_index(arr, 25)                                                                      
Out[22]: 3

In [23]: get_index(arr, 2)                                                                       
Out[23]: 1

In [24]: get_index(arr, 4)    #returns `None` for  item not found.  

答案 1 :(得分:0)

由于您的查找表不仅仅是任何查找表,而是频率列表,您可能需要考虑以下选项:

>>> x = np.array([1,  1,  1,  2,  25, 2,  2,  5,  1,  1])
>>> x_unq, x_idx = np.unique(x, return_inverse=True)
>>> np.take(np.bincount(x_idx), x_idx)
array([5, 5, 5, 3, 1, 3, 3, 1, 5, 5], dtype=int64)

即使您的查找表更复杂,即:

>>> lut = np.array([[ 1, 10],
...                 [ 2,  9],
...                 [ 5,  8],
...                 [25,  7]])

如果你能负担得起np.unique(它对数组进行排序,因此它是n log n)调用return_index,你可以使用小的连续整数作为索引,这通常可以做事情更容易,例如使用np.searchsorted你会这样做:

>>> np.take(lut[:, 1], np.take(np.searchsorted(lut[:, 0], x_unq), x_idx))
array([10, 10, 10,  9,  7,  9,  9,  8, 10, 10])