在python ndarray中查找重复行的索引

时间:2016-07-27 10:39:47

标签: python numpy multidimensional-array data-science

我编写了for循环来枚举包含n行28x28像素值的多维ndarray。

我正在寻找重复的每一行的索引以及没有重复的重复索引。

我发现这段代码here(感谢unutbu)并将其修改为读取ndarray,它在70%的时间内都有效,但有30%的时间它将错误的图像识别为重复。

如何改进检测正确的行?

def overlap_same(arr):
seen = []
dups = collections.defaultdict(list)
for i, item in enumerate(arr):
    for j, orig in enumerate(seen):
        if np.array_equal(item, orig):
            dups[j].append(i)
            break
    else:
        seen.append(item)
return dups

e.g。 return overlap_same(train)返回:

defaultdict(<type 'list'>, {34: [1388], 35: [1815], 583: [3045], 3208:
[4426], 626: [824], 507: [4438], 188: [338, 431, 540, 757, 765, 806,
808, 834, 882, 1515, 1539, 1715, 1725, 1789, 1841, 2038, 2081, 2165,
2170, 2300, 2455, 2683, 2733, 2957, 3290, 3293, 3311, 3373, 3446, 3542,
3565, 3890, 4110, 4197, 4206, 4364, 4371, 4734, 4851]})

在matplotlib上绘制正确案例的一些样本给出:

fig = plt.figure()
a=fig.add_subplot(1,2,1)
plt.imshow(train[35])
a.set_title('train[35]')
a=fig.add_subplot(1,2,2)
plt.imshow(train[1815])
a.set_title('train[1815]')
plt.show

train data 35 vs 1815

这是正确的

然而:

fig = plt.figure()
a=fig.add_subplot(1,2,1)
plt.imshow(train[3208])
a.set_title('train[3208]')
a=fig.add_subplot(1,2,2)
plt.imshow(train[4426])
a.set_title('train[4426]')
plt.show

enter image description here

不正确,因为它们不匹配

样本数据(火车[:3])

array([[[-0.5       , -0.5       , -0.5       , ...,  0.48823529,
      0.5       ,  0.17058824],
    [-0.5       , -0.5       , -0.5       , ...,  0.48823529,
      0.5       , -0.0372549 ],
    [-0.5       , -0.5       , -0.5       , ...,  0.5       ,
      0.47647059, -0.24509804],
    ..., 
    [-0.49215686,  0.34705883,  0.5       , ..., -0.5       ,
     -0.5       , -0.5       ],
    [-0.31176472,  0.44901961,  0.5       , ..., -0.5       ,
     -0.5       , -0.5       ],
    [-0.11176471,  0.5       ,  0.49215686, ..., -0.5       ,
     -0.5       , -0.5       ]],

   [[-0.24509804,  0.2764706 ,  0.5       , ...,  0.5       ,
      0.25294119, -0.36666667],
    [-0.5       , -0.47254902, -0.02941176, ...,  0.20196079,
     -0.46862745, -0.5       ],
    [-0.49215686, -0.5       , -0.5       , ..., -0.47647059,
     -0.5       , -0.49607843],
    ..., 
    [-0.49215686, -0.49607843, -0.5       , ..., -0.5       ,
     -0.5       , -0.49215686],
    [-0.5       , -0.5       , -0.26862746, ...,  0.13137256,
     -0.46470588, -0.5       ],
    [-0.30000001,  0.11960784,  0.48823529, ...,  0.5       ,
      0.28431374, -0.24117647]],

   [[-0.5       , -0.5       , -0.5       , ..., -0.5       ,
     -0.5       , -0.5       ],
    [-0.5       , -0.5       , -0.5       , ..., -0.5       ,
     -0.5       , -0.5       ],
    [-0.5       , -0.5       , -0.5       , ..., -0.5       ,
     -0.5       , -0.5       ],
    ..., 
    [-0.5       , -0.5       , -0.5       , ...,  0.48431373,
      0.5       ,  0.31568629],
    [-0.5       , -0.49215686, -0.5       , ...,  0.49215686,
      0.5       ,  0.04901961],
    [-0.5       , -0.5       , -0.5       , ...,  0.04117647,
     -0.17450981, -0.45686275]]], dtype=float32)

1 个答案:

答案 0 :(得分:2)

numpy_indexed包具有很多功能,可以有效地解决这些类型的问题。

例如,(不像numpy的内置独特),这将找到你独特的图像:

import numpy_indexed as npi
unique_training_images = npi.unique(train)

或者,如果要查找每个唯一组的所有索引,可以使用:

indices = npi.group_by(train).split(np.arange(len(train)))

请注意,这些函数没有像原始帖子那样具有二次时间复杂度,并且完全向量化,因此很可能效率更高。此外,与pandas不同,它没有首选的数据格式,并且完全具有nd-array功能,因此对具有shape [n_images,28,28]的数组进行操作只是“有效”。