How Pytorch Tensor get the index of specific value

时间:2017-12-18 06:13:16

标签: python pytorch

In python list, we can use list.index(somevalue). How can pytorch do this?
For example:

    a=[1,2,3]
    print(a.index(2))

Then, 1 will be output. How can a pytorch tensor do this without converting it to a python list?

8 个答案:

答案 0 :(得分:16)

我认为没有list.index()到pytorch函数的直接翻译。但是,您可以使用tensor==number然后使用nonzero()函数来获得类似的结果。例如:

t = torch.Tensor([1, 2, 3])
print ((t == 2).nonzero())

这段代码返回

  

1

     

[torch.LongTensor,大小为1x1]

答案 1 :(得分:4)

对于多维张量,您可以:

(tensor == target_value).nonzero(as_tuple=True)

生成的张量的形状为 number_of_matches x tensor_dimension。例如,假设 tensor 是一个 3 x 4 张量(这意味着维度为 2),结果将是一个二维张量,其中包含行中匹配项的索引。

tensor = torch.Tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])
(tensor == 2).nonzero(as_tuple=False)
>>> tensor([[0, 1],
        [0, 2],
        [1, 2]])

答案 2 :(得分:1)

可以通过如下转换为numpy来完成

import torch
x = torch.range(1,4)
print(x)
===> tensor([ 1.,  2.,  3.,  4.]) 
nx = x.numpy()
np.where(nx == 3)[0][0]
===> 2

答案 3 :(得分:0)

对于浮点张量,我使用它来获取张量中元素的索引。

print((torch.abs((torch.max(your_tensor).item()-your_tensor))<0.0001).nonzero())

这里我要获取浮点张量中max_value的索引,也可以像这样放置您的值以获取张量中任何元素的索引。

print((torch.abs((YOUR_VALUE-your_tensor))<0.0001).nonzero())

答案 4 :(得分:0)

用于在一维张量/数组中查找元素的索引 例子

mat=torch.tensor([1,8,5,3])

找到5的索引

five=5

numb_of_col=4
for o in range(numb_of_col):
   if mat[o]==five:
     print(torch.tensor([o]))

查找2d / 3d张量隐式为1d的元素索引 #ie example.view(元素数量)

示例

mat=torch.tensor([[1,2],[4,3])
#to find index of 2

five = 2
mat=mat.view(4)
numb_of_col = 4
for o in range(numb_of_col):
   if mat[o] == five:
     print(torch.tensor([o]))    

答案 5 :(得分:0)

基于其他人的回答:

t = torch.Tensor([1, 2, 3])
print((t==1).nonzero().item())

答案 6 :(得分:0)

已经给出的答案很好,但是当我尝试没有匹配时,它们无法处理。为此,请参阅:

def index(tensor: Tensor, value, ith_match:int =0) -> Tensor:
    """
    Returns generalized index (i.e. location/coordinate) of the first occurence of value
    in Tensor. For flat tensors (i.e. arrays/lists) it returns the indices of the occurrences
    of the value you are looking for. Otherwise, it returns the "index" as a coordinate.
    If there are multiple occurences then you need to choose which one you want with ith_index.
    e.g. ith_index=0 gives first occurence.

    Reference: https://stackoverflow.com/a/67175757/1601580
    :return:
    """
    # bool tensor of where value occurred
    places_where_value_occurs = (tensor == value)
    # get matches as a "coordinate list" where occurence happened
    matches = (tensor == value).nonzero()  # [number_of_matches, tensor_dimension]
    if matches.size(0) == 0:  # no matches
        return -1
    else:
        # get index/coordinate of the occurence you want (e.g. 1st occurence ith_match=0)
        index = matches[ith_match]
        return index

感谢这个伟大的答案:https://stackoverflow.com/a/67175757/1601580

答案 7 :(得分:-1)

    import torch
    x_data = variable(torch.Tensor([[1.0], [2.0], [3.0]]))
    print(x_data.data[0])
    >>tensor([1.])
相关问题