合并包含相同值的数组

时间:2018-09-09 03:58:32

标签: python numpy

我需要从相关系数矩阵中获得一个高相关组,保留其中一个并排除另一个。但是我不知道如何优雅而有效地做到这一点。

这是一个类似的答案,但希望它将使用类似矢量的矩阵来完成。: Merge arrays if they contain one or more of the same value

例如:

a = np.array([[1,0,0,0,0,1],
              [0,1,0,1,0,0],
              [0,0,1,0,1,1],
              [0,1,0,1,0,0],
              [0,0,1,0,1,0],              
              [1,0,1,0,0,1]])

对角线:

(0,0),(1,1),(2,2)...(5,5)

其他:

(0,5),(1,3),(2,4),(2,5)

这三对因为彼此包含而合并成一组:

(0,2,4,5) = (0,5),(2,4),(2,5)

所以最终我需要输出: (我将使用结果对其他数据建立索引,因此决定在每个组中保留最大的索引)

out = [(0,2,4,5),(1,3)]

我认为最简单的方法是采用嵌套循环并遍历所有元素多次。我想有一个更简洁高效的实现方式,谢谢

这是一个循环实现,很抱歉很难看到:

a = np.array([[1,0,0,0,0,1],
              [0,1,0,1,0,0],
              [0,0,1,0,1,1],
              [0,1,0,1,0,0],
              [0,0,1,0,1,0],              
              [1,0,1,0,0,1]])

a[np.tril_indices(6, -1)]= 0     
a[np.diag_indices(6)]    = 0     
g = list(np.c_[np.where(a)])

p = {}; index = 1
while len(g)>0:
    x = g.pop(0)
    if not p:
        p[index] = list(x)
        for i,l in enumerate(g):
            if np.in1d(l,x[0]).any()|np.in1d(l,x[1]).any():
                n = list(g.pop(i))
                p[index].extend(n)
    else:
        T = False
        for key,v in p.items():
            if np.in1d(v,x[0]).any()|np.in1d(v,x[1]).any():
                v.extend(list(x))
                T = True
        if T==False:
            index += 1; p[index] = list(x)
            for i,l in enumerate(g):
                if np.in1d(l,x[0]).any()|np.in1d(l,x[1]).any():
                    n = list(g.pop(i))
                    p[index].extend(n)

for key,v in p.items():
    print key,np.unique(v)

退出:

1 [0 2 4 5]
2 [1 3]

1 个答案:

答案 0 :(得分:1)

可以使用this answer解决合并/合并具有极值的对的中心问题。

因此,上面的代码可能会被重写为:

a = np.array([[1,0,0,0,0,1],
              [0,1,0,1,0,0],
              [0,0,1,0,1,1],
              [0,1,0,1,0,0],
              [0,0,1,0,1,0],              
              [1,0,1,0,0,1]])

a[np.tril_indices(6, -1)]= 0     
a[np.diag_indices(6)]    = 0     
g = np.c_[np.where(a)].tolist()

def consolidate(items):
    items = [set(item.copy()) for item in items]
    for i, x in enumerate(items):
        for j, y in enumerate(items[i + 1:]):
            if x & y:
                items[i + j + 1] = x | y
                items[i] = None
    return [sorted(x) for x in items if x]

p = {i + 1: x for i, x in enumerate(sorted(consolidate(g)))}