如何查找数组中的任何列是否具有重复值

时间:2017-02-02 21:02:24

标签: numpy

假设我有一个numpy矩阵A

A = array([[ 0.5,  0.5,  3.7],
           [ 3.8,  2.7,  3.7],
           [ 3.3,  1.0,  0.2]])

我想知道是否至少有两行ii'A[i, j]=A[i', j]列为j

在示例Ai=0i'=1 j=2中,答案为yes

我该怎么做?

我试过了:

def test(A, n):
    for j in range(n):
        i = 0
        while i < n:
            a = A[i, j]
            for s in range(i+1, n):
                if A[s, j] == a:
                    return True
            i += 1
    return False

有更快/更好的方法吗?

2 个答案:

答案 0 :(得分:3)

有多种方法可以检查重复项。我们的想法是尽可能少地使用Python代码中的循环来执行此操作。我将在这里介绍几种方式:

  1. 使用np.unique。你仍然需要遍历列,因为unique接受axis参数是没有意义的,因为每列可能有不同数量的唯一元素。虽然它仍然需要循环,unique允许您查找重复元素的位置和其他统计数据:

    def test(A):
        for i in A.shape[1]:
            if np.unique(A[:, i]).size < A.shape[0]:
                return True
        return False
    

    使用此方法,您基本上可以检查列中唯一元素的数量是否等于列的大小。如果没有,则有重复。

  2. 使用np.sortnp.diffnp.any。这是一个完全向量化的解决方案,不需要任何循环,因为您可以为每个函数指定一个轴:

    def test(A):
        return np.any(diff(np.sort(A, axis=0), axis=0) == 0)
    

    字面上读取&#34;如果逐列排序数组中的任何列方差异为零,则返回True&#34;。排序数组中的零差异意味着存在相同的元素。 axis=0 sortdiff分别对每列进行操作。

  3. 您永远不需要传入n,因为矩阵的大小是在属性shape中编码的。如果您需要查看矩阵的子集,只需使用索引传入子集。它不会复制数据,只返回具有所需尺寸的视图对象。

答案 1 :(得分:1)

没有numpy的解决方案如下所示:首先,用zip()交换列和行

zipped = zip(*A)

然后检查是否有任何 now row 有任何重复项。您可以通过将列表转换为集合来检查重复项,从而丢弃重复项,并检查长度。

has_duplicates = any(len(set(row)) != len(row) for row in zip(*A))

最有可能比纯粹的numpy解决方案更慢,也更省内存,但这可能有助于提高清晰度