对于给定条件,获取2D张量A中值的索引,使用那些索引3D张量B中的值

时间:2019-09-27 14:31:36

标签: python multidimensional-array pytorch tensor tensor-indexing

对于给定的2D张量,我想检索所有值为CellOverrideLabelAccumulator的索引。我希望能够简单地使用1,它将返回torch.nonzero(a == 1).squeeze()。但是,相反,tensor([1, 3, 2])返回2D张量(没关系),每行有两个值(这不是我期望的值)。然后,应使用返回的索引来索引3D张量的第二维(索引1),再次返回2D张量。

torch.nonzero(a == 1)

显然,这不起作用,导致aRunTimeError:

  

RuntimeError:无效参数4:索引张量必须具有相同的   尺寸作为输入张量   C:\ w \ 1 \ s \ windows \ pytorch \ aten \ src \ TH / generic / THTensorEvenMoreMath.cpp:453

看来import torch a = torch.Tensor([[12, 1, 0, 0], [4, 9, 21, 1], [10, 2, 1, 0]]) b = torch.rand(3, 4, 8) print('a_size', a.size()) # a_size torch.Size([3, 4]) print('b_size', b.size()) # b_size torch.Size([3, 4, 8]) idxs = torch.nonzero(a == 1) print('idxs_size', idxs.size()) # idxs_size torch.Size([3, 2]) print(b.gather(1, idxs)) 并不是我所期望的,我也无法按照我的想法使用它。 idxs

idxs

但是通读documentation我不明白为什么我还要在结果张量中重新获得行索引。现在,我知道我可以通过切片tensor([[0, 1], [1, 3], [2, 2]]) 来获取正确的idx,但是仍然不能将这些值用作3D张量的索引,因为会产生与以前相同的错误。是否可以使用一维张量索引来选择给定维度上的项目?

4 个答案:

答案 0 :(得分:1)

您可以简单地对其进行切片并将其作为索引传递,如下所示:

In [193]: idxs = torch.nonzero(a == 1)     
In [194]: c = b[idxs[:, 0], idxs[:, 1]]  

In [195]: c   
Out[195]: 
tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
        [0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
        [0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])

或者,更简单和我更喜欢的方法是只使用torch.where(),然后直接在张量b中建立索引,如下所示:

In [196]: b[torch.where(a == 1)]  
Out[196]: 
tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
        [0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
        [0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])

关于以上使用torch.where()的方法的更多解释:它基于advanced indexing的概念起作用。也就是说,当我们使用序列对象元组(例如张量元组,列表元组,元组元组等)对张量进行索引时。

# some input tensor
In [207]: a  
Out[207]: 
tensor([[12.,  1.,  0.,  0.],
        [ 4.,  9., 21.,  1.],
        [10.,  2.,  1.,  0.]])

对于基本切片,我们需要一个整数索引元组:

   In [212]: a[(1, 2)] 
   Out[212]: tensor(21.)

要使用高级索引实现相同目的,我们将需要一个序列对象元组:

# adv. indexing using a tuple of lists
In [213]: a[([1,], [2,])] 
Out[213]: tensor([21.])

# adv. indexing using a tuple of tuples
In [215]: a[((1,), (2,))]  
Out[215]: tensor([21.])

# adv. indexing using a tuple of tensors
In [214]: a[(torch.tensor([1,]), torch.tensor([2,]))] 
Out[214]: tensor([21.])

返回的张量的尺寸总是比输入张量的尺寸小一维。

答案 1 :(得分:0)

import torch

a = torch.Tensor([[12, 1, 0, 0],
                  [4, 9, 21, 1],
                  [10, 2, 1, 0]])

b = torch.rand(3, 4, 8)

print('a_size', a.size())
# a_size torch.Size([3, 4])
print('b_size', b.size())
# b_size torch.Size([3, 4, 8])

#idxs = torch.nonzero(a == 1, as_tuple=True)
idxs = torch.nonzero(a == 1)
#print('idxs_size', idxs.size())

print(torch.index_select(b,1,idxs[:,1]))

答案 2 :(得分:0)

假设b的三个维度为batch_size x sequence_length x features(b x s x个专长),则可以实现以下预期结果。

import torch

a = torch.Tensor([[12, 1, 0, 0],
                  [4, 9, 21, 1],
                  [10, 2, 1, 0]])

b = torch.rand(3, 4, 8)
print(b.size())
# b x s x feats
idxs = torch.nonzero(a == 1)[:, 1]
print(idxs.size())
# b
c = b[torch.arange(b.size(0)), idxs]
print(c.size())
# b x feats

答案 3 :(得分:0)

作为@ kmario23解决方案的补充,您仍然可以实现类似的结果

b[torch.nonzero(a==1,as_tuple=True)]