如何散列numpy数组以检查重复项

时间:2016-09-25 00:18:31

标签: python arrays numpy hash

我已经搜索了一些教程等来帮助解决这个问题,但似乎找不到任何东西。

我有两个n维numpy数组列表(某些图像的3D数组形式),我想检查每个列表中的重叠图像。让我们说列表a是训练集,列表b是验证集。 一种解决方案就是使用嵌套循环并使用np.array(a[i], b[j])检查每对数组是否相等,但这很慢(每个列表中有大约200,000个numpy数组)并且坦率地说非常恶心。

我认为实现这一目标的更优雅的方法是散列每个列表中的每个numpy数组,然后使用这些哈希表比较每个条目。

首先,这个解决方案是否正确,其次,我将如何实现这一目标? 下面是一些数据的例子。

train_dataset[:3]
array([[[-0.5       , -0.49607843, -0.5       , ..., -0.5       ,
         -0.49215686, -0.5       ],
        [-0.49607843, -0.47647059, -0.5       , ..., -0.5       ,
         -0.47254902, -0.49607843],
        [-0.49607843, -0.49607843, -0.5       , ..., -0.5       ,
         -0.49607843, -0.49607843],
        ..., 
        [-0.49607843, -0.49215686, -0.5       , ..., -0.5       ,
         -0.49215686, -0.49607843],
        [-0.49607843, -0.47647059, -0.5       , ..., -0.5       ,
         -0.47254902, -0.49607843],
        [-0.5       , -0.49607843, -0.5       , ..., -0.5       ,
         -0.49607843, -0.5       ]],

       [[-0.5       , -0.5       , -0.5       , ...,  0.48823529,
          0.5       ,  0.1509804 ],
        [-0.5       , -0.5       , -0.5       , ...,  0.48431373,
          0.14705883, -0.32745099],
        [-0.5       , -0.5       , -0.5       , ..., -0.32745099,
         -0.5       , -0.49607843],
        ..., 
        [-0.5       , -0.44901961,  0.1509804 , ..., -0.5       ,
         -0.5       , -0.5       ],
        [-0.49607843, -0.49607843, -0.49215686, ..., -0.5       ,
         -0.5       , -0.5       ],
        [-0.5       , -0.49607843, -0.48823529, ..., -0.5       ,
         -0.5       , -0.5       ]],

       [[-0.5       , -0.5       , -0.5       , ..., -0.5       ,
         -0.5       , -0.5       ],
        [-0.5       , -0.5       , -0.5       , ..., -0.5       ,
         -0.5       , -0.5       ],
        [-0.5       , -0.5       , -0.49607843, ..., -0.5       ,
         -0.5       , -0.5       ],
        ..., 
        [-0.5       , -0.5       , -0.5       , ..., -0.48823529,
         -0.5       , -0.5       ],
        [-0.5       , -0.5       , -0.5       , ..., -0.5       ,
         -0.5       , -0.5       ],
        [-0.5       , -0.5       , -0.5       , ..., -0.5       ,
         -0.5       , -0.5       ]]], dtype=float32)

我提前感谢你的帮助。

3 个答案:

答案 0 :(得分:0)

您可以使用numpy的intersect1d(一维集相交)函数在数组之间找到重复项。

duplicate_images = np.intersect1d(train_dataset, test_dataset) 

我使用tensorflow tutorials之一的训练和测试集(分别为55000和10000阵列)来计时,我猜测它与您的数据类似。使用intersect1d,在我的机器上完成大约需要2.4秒(参数assume_unique=True只用了1.3秒)。像你描述的成对比较花了几分钟。

修改

这个答案(上面)没有比较每个"图像"数组,正如@ mbhall88在注释中指出的那样,它比较了数组中的元素,而不是数组本身。为了确保比较数组,你仍然可以使用intersect1d,但你必须首先解释dtypes,如here所述。但是,该答案中的示例涉及2d数组,并且由于您正在使用3d数组,因此您应首先展平后两个维度。您应该可以执行以下操作:

def intersect3d(A,B, assume_unique=False):
    # get the original shape of your arrays
    a1d, a2d, a3d = A.shape
    # flatten the 2nd and 3rd dimensions in your arrays
    A = A.reshape((a1d,a2d*a3d))
    B = B.reshape((len(B),a2d*a3d))
    # define a structured dtype so you can treat your arrays as single "element"
    dtype=(', '.join([str(A.dtype)]*ncols))
    # find the duplicate elements
    C = np.intersect1d(A.view(dtype), B.view(dtype), assume_unique=assume_unique)
    # reshape the result and return
    return C.view(A.dtype).reshape(-1, ncols).reshape((len(C),a2d,a3d))

答案 1 :(得分:0)

numpy_indexed包(免责声明:我是它的作者)为此提供了有效的单行程序:

import numpy_indexed as npi
duplicate_images = npi.intersection(train_dataset, test_dataset) 

此外,您可能会发现很多相关的功能。

答案 2 :(得分:0)

提出某事并不是那么困难:

from collections import defaultdict
import numpy as np

def arrayhash(arr):
    u = arr.view('u' + str(arr.itemsize))
    return np.bitwise_xor.reduce(u.ravel())

def do_the_thing(a, b, hashfunc=arrayhash):
    table = defaultdict(list)
    for i, a_i in enumerate(a):
        table[hashfunc(a_i)].append(i)

    indices = []
    for j, b_j in enumerate(b):
        candidates = table[hashfunc(b_j)]
        for i in candidates:
            if np.array_equiv(a[i], b_j):
                indices.append((i,j))

    return indices

但请注意:

  • 检查浮点相等通常是一个坏主意,因为精度和舍入误差有限。着名的例子:

    >>> 0.1 + 0.2 == 0.3
    False
    
  • NaN不等于自己:

    >>> np.nan == np.nan
    False
    
  • 上面的简单哈希函数关于浮点数的位表示,但是在存在负零和信令NaN的情况下这是有问题的。

另请参阅此问题中的讨论:Good way to hash a float vector?