如何在Numpy中找到矩阵的矩阵公共成员

时间:2020-06-05 08:39:13

标签: numpy

我有一个二维矩阵A和一个向量B。我想找到A中元素的所有行索引,这些元素也包含在B中。

A = np.array([[1,9,5], [8,4,9], [4,9,3], [6,7,5]], dtype=int)
B = np.array([2, 4, 8, 10, 12, 18], dtype=int)

我当前的解决方案是一次仅将A与B的一个元素进行比较,但这太慢了:

res = np.array([], dtype=int)
for i in range(B.shape[0]):
    cres, _ = (B[i] == A).nonzero()
    degElem = np.append(res, cres)
res = np.unique(res)

以下Matlab语句可以解决我的问题:

find(any(reshape(any(reshape(A, prod(size(A)), 1) == B, 2),size(A, 1),size(A, 2)), 2))

但是,在Numpy中比较行和列向量不会像在Matlab中那样创建布尔交集矩阵。 在Numpy中有正确的方法吗?

1 个答案:

答案 0 :(得分:1)

我们可以使用np.isin遮罩。

要获取所有行号,它应该是-

np.where(np.isin(A,B).T)[1]

如果您需要根据每个元素的出现将它们分开-

[np.flatnonzero(i) for i in np.isin(A,B).T if i.any()]

发布的MATLAB代码似乎正在执行broadcasting。因此,等效的是-

np.where(B[:,None,None]==A)[1]