我有两个张量,张量a和张量b。
我想获取张量b中所有值的索引。
例如。
a = torch.Tensor([1,2,2,3,4,4,4,5])
b = torch.Tensor([1,2,4])
我想要1, 2, 4
在张量a中的索引。我可以通过以下代码来做到这一点。
a = torch.Tensor([1,2,2,3,4,4,4,5])
b = torch.Tensor([1,2,4])
mask = torch.zeros(a.shape).type(torch.bool)
print(mask)
for e in b:
mask = mask + (a == e)
print(mask)
没有for
怎么办?
答案 0 :(得分:1)
这是您想要的吗? :
np.in1d(a.numpy(), b.numpy())
将导致:
array([ True, True, True, False, True, True, True, False])
答案 1 :(得分:0)
如果只是不想使用for循环,则可以使用列表推导:
mask = [a[index] for index in b]
如果甚至不想使用“ for”一词,则始终可以将张量转换为numpy并使用numpy索引。
mask = torch.tensor(a.numpy()[b.numpy()])
更新
可能误解了您的问题。在这种情况下,我要说的最好的方法是通过列表理解。 (切片可能无法实现这一点。
mask = [index for index,value in enumerate(a) if value in b.tolist()]
这会遍历a中的每个元素,获取它们的索引和值,如果值在b内,则获取索引。