Pytorch张量-如何通过特定张量获取索引

时间:2018-08-06 09:00:16

标签: python pytorch

我有张量

t = torch.tensor([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]])

和查询张量

q = torch.tensor([1, 0, 0, 0])

有没有办法像这样获得q的索引

indexes = t.index(q) # get back [0, 3]

在火炬中?

3 个答案:

答案 0 :(得分:1)

请尝试此操作,我没有在此PC上安装手电筒。

import torch

t = torch.tensor([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]])
q = torch.tensor([1, 0, 0, 0])

index = torch.nonzero(torch.sum((t == q), dim=1) == t.shape[1])

编辑说明:针对Shai提出的问题进行了编辑。

答案 1 :(得分:1)

怎么样

In [1]: torch.nonzero((t == q).sum(dim=1) == t.size(1))
Out[1]: 
tensor([[ 0],
        [ 3]])

比较t == qtq之间执行逐元素比较,因为您要查找整个行匹配项,因此需要{{1} },然后看看哪一行是.sum(dim=1)的完美匹配。


从v0.4.1开始,torch.all()支持== t.size(1)参数:

dim

答案 2 :(得分:0)

更多pytorch本机方法是:

torch.all(q.repeat((t.shape[1],1))==t, dim=1)