火炬在2个2D张量中找到匹配行的索引

时间:2020-01-12 14:52:01

标签: pytorch tensor

我有两个长度不同的2D张量,它们都是同一原始2d张量的不同子集,我想找到所有匹配的“行”
例如

A = [[1,2,3],[4,5,6],[7,8,9],[3,3,3]
B = [[1,2,3],[7,8,9],[4,4,4]]
torch.2dintersect(A,B) -> [0,2] (the indecies of A that B also have)

我只看到了numpy解决方案,它使用dtype作为字典,不适用于pytorch。


这是我在numpy

中的操作方法
arr1 = edge_index_dense.numpy().view(np.int32)
arr2 = edge_index2_dense.numpy().view(np.int32)
arr1_view = arr1.view([('', arr1.dtype)] * arr1.shape[1])
arr2_view = arr2.view([('', arr2.dtype)] * arr2.shape[1])
intersected = np.intersect1d(arr1_view, arr2_view, return_indices=True)

2 个答案:

答案 0 :(得分:2)

此答案是在OP用其他限制条件对该问题进行了相当多的更改之前发布的。

TL; DR ,您可以执行以下操作:

torch.where((A == B).all(dim=1))[0]

首先,假设您拥有:

import torch
A = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
B = torch.Tensor([[1,2,3],[4,4,4],[7,8,9]])

我们可以检查A == B是否返回:

>>> A == B
tensor([[ True,  True,  True],
        [ True, False, False],
        [ True,  True,  True]])

因此,我们想要的是:它们全部为True的行。为此,我们可以使用.all()操作并指定感兴趣的维度,在我们的情况下为1

>>> (A == B).all(dim=1)
tensor([ True, False,  True])

您真正想知道的是True的位置。为此,我们可以获得torch.where()函数的第一个输出:

>>> torch.where((A == B).all(dim=1))[0]
tensor([0, 2])

答案 1 :(得分:1)

如果A和B是2D张量,则以下代码将找到索引A[indices] == B。如果多个索引满足此条件,则返回找到的第一个索引。如果不是B的所有元素都存在于A中,则相应的索引将被忽略。

values, indices = torch.topk(((A.t() == B.unsqueeze(-1)).all(dim=1)).int(), 1, 1)
indices = indices[values!=0]
# indices = tensor([0, 2])