Python:从批处理中删除相等/相似的numpy数组的快速方法

时间:2018-08-15 07:42:24

标签: arrays numpy

我有一批与灰度图像相对应的28x28 numpy数组,我想擦除类似的数组(尽管我看起来很原始,但可以满足np.sum(np.abs(arrayA-arrayB))<50之类的东西像那样)。有没有人比循环更了解这一点。

使用np.unique删除重复的图像很容易,但是这对我来说很难。非常感谢

1 个答案:

答案 0 :(得分:2)

鉴于您的数据集不会变得太大,广播对于这种任务非常有用的工具。它允许您执行这些“外部”操作,例如将每个元素与其他每个元素的差取值,创建相似度矩阵:

要了解其工作原理,让我们看一维情况:

import numpy as np

data = numpy.random.rand(12) * 256
# Make sure we have some similar elements in `data`
data[0] = data[3]
data[7] = data[10]

diff = np.abs(data[None, :] - data[:, None])
diff.shape
# (12, 12)

要了解正在发生的事情,请看一下输出:

plt.imshow(diff)
plt.show()

现在,我们了解了如何利用广播,让我们将其适应您的3D情况:

data = np.random.rand(12, 28, 28) * 256
# Make sure we have some similar elements in `data`
data[0, ...] = data[3, ...]
data[7, ...] = data[10, ...]

diff = np.abs(data[None, ...] - data[:, None, ...])
diff.shape
# (12, 12, 28, 28)

如您所见,我们得到了一个张量,该张量包含每个像素与每个其他图块中相同像素的差异。要获得此差的总和,请对最后两个轴求和

diff = np.sum(diff, axis=(-1, -2))
diff.shape
# (12, 12)

再次,看看发生了什么事

plt.imshow(diff)
plt.show()

要查找重复的元素,我们可以使用您的条件:

diff = diff < 5

但是请注意,主对角线现在将是所有True值(每个图块都与自身进行了比较,差异显然是0)。因此,我们将其设置为False

np.fill_diagonal(diff, False)

就像健全性检查一样,现在让我们搜索True值:

np.where(diff)
# (array([ 0,  3,  7, 10]), array([ 3,  0, 10,  7]))

好的,这些值似乎合理。

现在要从diff数组中获取布尔列或行掩码,让我们按行搜索任何True值:

mask = np.any(diff, axis=0)
# array([ True, False, False,  True, False, False, False,  True, False, False,  True, False])

并使用此掩码过滤data。这将删除列0, 3, 7, 10

data = data[~mask, ...]
data.shape
# (8, 28, 28)

如果您想保留其中一个重复项,请仅在True的上三角形或下三角形中搜索diff,并保留其余的内容:

mask = np.any(np.triu(diff), axis=0)
# array([False, False, False,  True, False, False, False, False, False, False,  True, False])

这将删除列3, 10

data = data[~mask, ...]
data.shape
# (10, 28, 28)