假设我有一个numpy矩阵A
A = array([[ 0.5, 0.5, 3.7],
[ 3.8, 2.7, 3.7],
[ 3.3, 1.0, 0.2]])
我想知道是否至少有两行i
和i'
,A[i, j]=A[i', j]
列为j
?
在示例A
,i=0
和i'=1
j=2
中,答案为yes
。
我该怎么做?
我试过了:
def test(A, n):
for j in range(n):
i = 0
while i < n:
a = A[i, j]
for s in range(i+1, n):
if A[s, j] == a:
return True
i += 1
return False
有更快/更好的方法吗?
答案 0 :(得分:3)
有多种方法可以检查重复项。我们的想法是尽可能少地使用Python代码中的循环来执行此操作。我将在这里介绍几种方式:
使用np.unique
。你仍然需要遍历列,因为unique
接受axis
参数是没有意义的,因为每列可能有不同数量的唯一元素。虽然它仍然需要循环,unique
允许您查找重复元素的位置和其他统计数据:
def test(A):
for i in A.shape[1]:
if np.unique(A[:, i]).size < A.shape[0]:
return True
return False
使用此方法,您基本上可以检查列中唯一元素的数量是否等于列的大小。如果没有,则有重复。
使用np.sort
,np.diff
和np.any
。这是一个完全向量化的解决方案,不需要任何循环,因为您可以为每个函数指定一个轴:
def test(A):
return np.any(diff(np.sort(A, axis=0), axis=0) == 0)
字面上读取&#34;如果逐列排序数组中的任何列方差异为零,则返回True&#34;。排序数组中的零差异意味着存在相同的元素。 axis=0
sort
和diff
分别对每列进行操作。
您永远不需要传入n
,因为矩阵的大小是在属性shape
中编码的。如果您需要查看矩阵的子集,只需使用索引传入子集。它不会复制数据,只返回具有所需尺寸的视图对象。
答案 1 :(得分:1)
没有numpy的解决方案如下所示:首先,用zip()
交换列和行
zipped = zip(*A)
然后检查是否有任何 now row 有任何重复项。您可以通过将列表转换为集合来检查重复项,从而丢弃重复项,并检查长度。
has_duplicates = any(len(set(row)) != len(row) for row in zip(*A))
最有可能比纯粹的numpy解决方案更慢,也更省内存,但这可能有助于提高清晰度