Numpy intersect1d with array with matrix as elements

时间:2017-01-01 15:40:03

标签: python arrays numpy

我有两个数组,一个是形状(200000, 28, 28),另一个是形状(10000, 28, 28),所以实际上有两个数组,矩阵作为元素。 现在我想计算并获得两个数组中重叠的所有元素(格式为(N, 28, 28))。使用正常的for循环是慢的方法,所以我用numpys intersect1d方法尝试它,但我不知道如何将它应用于这种类型的数组。

1 个答案:

答案 0 :(得分:4)

使用this question about unique rows

中的方法
def intersect_along_first_axis(a, b):
    # check that casting to void will create equal size elements
    assert a.shape[1:] == b.shape[1:]
    assert a.dtype == b.dtype

    # compute dtypes
    void_dt = np.dtype((np.void, a.dtype.itemsize * np.prod(a.shape[1:])))
    orig_dt = np.dtype((a.dtype, a.shape[1:]))

    # convert to 1d void arrays
    a = np.ascontiguousarray(a)
    b = np.ascontiguousarray(b)
    a_void = a.reshape(a.shape[0], -1).view(void_dt)
    b_void = b.reshape(b.shape[0], -1).view(void_dt)

    # intersect, then convert back
    return np.intersect1d(b_void, a_void).view(orig_dt)

请注意,使用void使用浮点数是不安全的,因为这会导致-0不等于0