pytorch张量获取具有特定值的元素的索引?

时间:2020-10-11 05:16:46

标签: python numpy tensorflow pytorch

我有两个张量,张量a和张量b。

我想获取张量b中所有值的索引。

例如。

a = torch.Tensor([1,2,2,3,4,4,4,5])
b = torch.Tensor([1,2,4])

我想要1, 2, 4在张量a中的索引。我可以通过以下代码来做到这一点。

a = torch.Tensor([1,2,2,3,4,4,4,5])
b = torch.Tensor([1,2,4])
mask = torch.zeros(a.shape).type(torch.bool)
print(mask)
for e in b:
    mask = mask + (a == e)
    print(mask)

没有for怎么办?

2 个答案:

答案 0 :(得分:1)

这是您想要的吗? :

np.in1d(a.numpy(), b.numpy())

将导致:

array([ True,  True,  True, False,  True,  True,  True, False])

答案 1 :(得分:0)

如果只是不想使用for循环,则可以使用列表推导:

mask = [a[index] for index in b]

如果甚至不想使用“ for”一词,则始终可以将张量转换为numpy并使用numpy索引。

mask = torch.tensor(a.numpy()[b.numpy()])

更新

可能误解了您的问题。在这种情况下,我要说的最好的方法是通过列表理解。 (切片可能无法实现这一点。

mask = [index for index,value in enumerate(a) if value in b.tolist()] 

这会遍历a中的每个元素,获取它们的索引和值,如果值在b内,则获取索引。