查找相等numpy 2D行的索引

时间:2016-05-02 11:58:58

标签: python-3.x numpy

我有一个2D numpy数组的Python列表(都具有相同的形状),我想提取相同数组的索引。我想出了这个:

a = np.array([[1, 2], [3, 4]])
b = np.array([[1, 2], [3, 4]])
c = np.array([[3, 4], [1, 2]])
d = np.array([[3, 4], [1, 2]])
e = np.array([[3, 4], [1, 2]])
f = np.array([[1, 2], [3, 4]])
g = np.array([[9, 9], [3, 4]])

li = [a, b, c, d, e, f, g]

indexes = list(range(len(li)))
equals = []
for i, a_i in enumerate(indexes):
    a_equals = []
    for j, b_i in enumerate(indexes[i+1:]):
        if np.array_equal(li[a_i], li[b_i]):
            del indexes[j]
            a_equals.append(b_i)
    if a_equals:
        equals.append((a_i, *a_equals))

print(equals)
# [(0, 1, 5), (2, 3, 4)]

它可以工作(你可以假设没有2D数组是空的)但是解决方案很笨拙并且可能很慢。有没有办法用Numpy更优雅地做到这一点?

3 个答案:

答案 0 :(得分:1)

鉴于列表中的输入数组具有相同的形状,您可以将数组列表连接到单个2D数组中,每行代表输入列表的每个元素。这使得进一步的计算更容易并且便于矢量化操作。实现看起来像这样 -

# Concatenate all elements into a 2D array
all_arr = np.concatenate(li).reshape(-1,li[0].size)

# Reduce each row with IDs such that each they represent indexing tuple 
ids = np.ravel_multi_index(all_arr.T,all_arr.max(0)+1)

# Tag each such IDs based on uniqueness against other IDs
_,unqids,C = np.unique(ids,return_inverse=True,return_counts=True)

# Sort the unique IDs and split into groups for final output
sidx = unqids.argsort()

# Mask corresponding to unqids that has ID counts > 1
mask = np.in1d(unqids,np.where(C>1)[0])

# Split masked sorted indices at places corresponding to cumsum-ed counts
out = np.split(sidx[mask[sidx]],C[C>1].cumsum())[:-1]

注意:如果连接的输入数组all_arr中有大量列,您可能希望使用np.cumprod手动获取索引ids ,就像这样 -

ids = all_arr.dot(np.append(1,(all_arr.max(0)+1)[::-1][:-1].cumprod())[::-1])

答案 1 :(得分:0)

也许你可以尝试itertools

import itertools
from collections import defaultdict

equals=defaultdict(list)
visited=[]
for a, b in itertools.combinations(enumerate(li), 2):
  if not b[0] in visited and np.array_equal(a[1], b[1]) :
    equals[a[0]].append(b[0])
    visited += (a[0],b[0])

print equals
# defaultdict(<type 'list'>, {0: [1, 5], 2: [3, 4]})

答案 2 :(得分:0)

使用numpy_indexed包可以优雅地解决这个问题(免责声明:我是它的作者):

import numpy_indexed as npi
print(npi.group_by(npi.as_index(li).inverse).split(np.arange(len(li))))

我怀疑发现这些指数很可能不是你的最终目标,而且如果你使用numpy_indexed稍微玩一下,你可能会发现存在更直接的终点目标路径。

实际上,删除单计数索引可能最好留作后处理步骤,尽管您也可以使用npi.multiplicity&gt; 1作为预处理步骤