python numpy加速2d重复搜索

时间:2017-09-18 17:01:06

标签: python performance numpy

我需要在2d numpy数组中找到重复项。结果,我想要一个与输入相同长度的列表,该列表指向第一次出现的相应值。例如,数组[[1,0,0],[1,0,0],[2,3,4]]有两个相等的元素0和1.该方法应该返回[0,0,2](参见以下代码中的示例)。 以下代码正在运行,但对于大型数组而言速度很慢。

import numpy as np


def duplicates(ar):
    """
    Args:
        ar (array_like): array

    Returns:
        list of int: int is pointing to first occurence of unique value
    """
    # duplicates array:
    dup = np.full(ar.shape[0], -1, dtype=int)
    for i in range(ar.shape[0]):
        if dup[i] != -1:
            # i is already found to be a
            continue
        else:
            dup[i] = i
        for j in range(i + 1, ar.shape[0]):
            if (ar[i] == ar[j]).all():
                dup[j] = i
    return dup


if __name__ == '__main__':
    n = 100
    # shortest extreme for n points
    a1 = np.array([[0, 1, 2]] * n)
    assert (duplicates(a1) == np.full(n, 0)).all(), True

    # longest extreme for n points
    a2 = np.linspace(0, 1, n * 3).reshape((n, 3))
    assert (duplicates(a2) == np.arange(0, n)).all(), True

    # test case
    a3 = np.array([[1, 0, 0], [1, 0, 0], [2, 3, 4]])
    assert (duplicates(a3) == [0, 0, 2]).all(), True

知道如何加快进程(例如避免第二次for循环)或替代实现吗? 干杯

3 个答案:

答案 0 :(得分:2)

您正在做的事情要求您在每个可能的配对中比较N行,每行长度为M。这意味着在没有重复的情况下,它最多可以缩放为O(N^2 * M)

更好的方法是散列每一行。如果哈希缩放所需的时间为O(M),那么这应该缩放为O(N * M)。你可以用字典做到这一点:

def duplicates(ar):
    """
    Args:
        ar (array_like): array

    Returns:
        list of int: int is pointing to first occurence of unique value
    """
    first_occurence = {}
    # duplicates array:
    dup = np.zeros(ar.shape[0], dtype=int)
    for i in range(ar.shape[0]):
        as_tuple = tuple(ar[i])
        if as_tuple not in first_occurence:
            first_occurence[as_tuple] = i
        dup[i] = first_occurence[as_tuple]
    return dup

答案 1 :(得分:1)

这是一种矢量化方法 -

def duplicates_1(a):
    sidx = np.lexsort(a.T)
    b = a[sidx]

    grp_idx0 = np.flatnonzero((b[1:] != b[:-1]).any(1))+1
    grp_idx = np.concatenate(( [0], grp_idx0, [b.shape[0] ] ))
    ids = np.repeat(range(len(grp_idx)-1), np.diff(grp_idx))
    sidx_mapped = argsort_unique(sidx)
    ids_mapped = ids[sidx_mapped]

    grp_minidx = sidx[grp_idx[:-1]]
    out = grp_minidx[ids_mapped]
    return out 

使用array-view概念使我们能够在1D级别工作,这是对第一种方法的修改 -

def duplicates_1_view1D(a):
    a1D = view1D(a)
    sidx0 = a1D.argsort()
    b0 = a1D[sidx0]

    N = len(b0)
    grp_idx0 = np.concatenate(( [0], np.flatnonzero(b0[1:] != b0[:-1])+1, [N] ))
    ids0 = np.repeat(range(len(grp_idx0)-1), np.diff(grp_idx0))
    sidx_mapped0 = argsort_unique(sidx0)
    ids_mapped0 = ids0[sidx_mapped0]

    grp_minidx0 = sidx0[grp_idx0[:-1]]
    out0 = grp_minidx0[ids_mapped0]
    return out0 

助手功能 -

# https://stackoverflow.com/a/44999009/ @Divakar
def view1D(a): # a is array
    a = np.ascontiguousarray(a)
    void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
    return a.view(void_dt).ravel()

# https://stackoverflow.com/a/43411559/ @Divakar
def argsort_unique(idx):
    n = idx.size
    sidx = np.empty(n,dtype=int)
    sidx[idx] = np.arange(n)
    return sidx

答案 2 :(得分:1)

我为Divakar和Jeremy的答案计算了我的代码示例中标有"#shortest extreme for n points"的两个测试用例。和#34; #n最长的n点"。所有答案都会产生预期的结果并极大地提高速度。似乎Divakars矢量化方法一直是最快的。 Minimum Time Maximum Time 谢谢。归功于Divakar和Jeremy。

修改 实施矢量化方法进一步测试显示出错误。对于示例数组

[[ 0.  0.  0.]
 [ 1.  0.  0.]
 [ 1.  1.  0.]
 [ 0.  1.  0.]
 [ 2.  0.  0.]
 [ 3.  0.  0.]
 [ 3.  1.  0.]
 [ 2.  1.  0.]]

向量化方法检索全0列表。 view1D是第二快的,所以我接受了。

<强> EDIT2: 迪瓦卡尔修复了这个bug。感谢