我有一批与灰度图像相对应的28x28 numpy数组,我想擦除类似的数组(尽管我看起来很原始,但可以满足np.sum(np.abs(arrayA-arrayB))<50之类的东西像那样)。有没有人比循环更了解这一点。
使用np.unique删除重复的图像很容易,但是这对我来说很难。非常感谢
答案 0 :(得分:2)
鉴于您的数据集不会变得太大,广播对于这种任务非常有用的工具。它允许您执行这些“外部”操作,例如将每个元素与其他每个元素的差取值,创建相似度矩阵:
要了解其工作原理,让我们看一维情况:
import numpy as np
data = numpy.random.rand(12) * 256
# Make sure we have some similar elements in `data`
data[0] = data[3]
data[7] = data[10]
diff = np.abs(data[None, :] - data[:, None])
diff.shape
# (12, 12)
要了解正在发生的事情,请看一下输出:
plt.imshow(diff)
plt.show()
现在,我们了解了如何利用广播,让我们将其适应您的3D情况:
data = np.random.rand(12, 28, 28) * 256
# Make sure we have some similar elements in `data`
data[0, ...] = data[3, ...]
data[7, ...] = data[10, ...]
diff = np.abs(data[None, ...] - data[:, None, ...])
diff.shape
# (12, 12, 28, 28)
如您所见,我们得到了一个张量,该张量包含每个像素与每个其他图块中相同像素的差异。要获得此差的总和,请对最后两个轴求和
diff = np.sum(diff, axis=(-1, -2))
diff.shape
# (12, 12)
再次,看看发生了什么事
plt.imshow(diff)
plt.show()
要查找重复的元素,我们可以使用您的条件:
diff = diff < 5
但是请注意,主对角线现在将是所有True
值(每个图块都与自身进行了比较,差异显然是0
)。因此,我们将其设置为False
:
np.fill_diagonal(diff, False)
就像健全性检查一样,现在让我们搜索True
值:
np.where(diff)
# (array([ 0, 3, 7, 10]), array([ 3, 0, 10, 7]))
好的,这些值似乎合理。
现在要从diff
数组中获取布尔列或行掩码,让我们按行搜索任何True
值:
mask = np.any(diff, axis=0)
# array([ True, False, False, True, False, False, False, True, False, False, True, False])
并使用此掩码过滤data
。这将删除列0, 3, 7, 10
data = data[~mask, ...]
data.shape
# (8, 28, 28)
如果您想保留其中一个重复项,请仅在True
的上三角形或下三角形中搜索diff
,并保留其余的内容:
mask = np.any(np.triu(diff), axis=0)
# array([False, False, False, True, False, False, False, False, False, False, True, False])
这将删除列3, 10
data = data[~mask, ...]
data.shape
# (10, 28, 28)