Pytorch Tensor如何获取元素索引?

时间:2019-09-14 08:27:32

标签: python pytorch torch tensor

我有两个名为 x list 的张量,其定义如下:

df = df[~df.apply(' '.join, 1).str.contains('|'.join(keywords), case=False)]
print (df)
    Brand     ID Description
1  iPhone  DF747     battery

现在,我想从列表获取元素 x 的索引。预期的输出是一个整数:

x = torch.tensor(3)
list = torch.tensor([1,2,3,4,5])

我如何轻松地完成工作?

1 个答案:

答案 0 :(得分:2)

import torch

x = torch.tensor(3)

list = torch.tensor([1,2,3,4,5])
idx = (list == x).nonzero().flatten()
print (idx.tolist()) # [2]

list = torch.tensor([1,2,3,3,5])
idx = (list == x).nonzero().flatten()
print (idx.tolist()) # [2, 3]