PyTorch:带有二维张量的索引高维张量

时间:2020-06-07 07:00:15

标签: numpy indexing pytorch

假设我有以下张量:

N = 2
k = 3
d = 2

L = torch.arange(N * k * d * d).view(N, k, d, d)
L
tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]],

         [[ 8,  9],
          [10, 11]]],


        [[[12, 13],
          [14, 15]],

         [[16, 17],
          [18, 19]],

         [[20, 21],
          [22, 23]]]])


index = torch.Tensor([0,1,0,0]).view(N,-1)
index
tensor([[0., 1.],
        [0., 0.]])

我现在想使用索引张量来挑选第二维上的对应矩阵,即我想得到类似的东西:

tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]]],


        [[[12, 13],
          [14, 15]],

         [[[[12, 13],
          [14, 15]]])

有什么想法可以实现这一目标吗? 非常感谢!

1 个答案:

答案 0 :(得分:1)

张量可以用在不同维度(张量的元组)上指定的多个张量索引,其中每个张量的 th-th 元素组合在一起以创建索引元组,即{{1} }导致索引data[indices_dim0, indices_dim1]data[indices_dim0[0], indices_dim1[0]]等。它们必须具有相同的长度data[indices_dim0[1], indices_dim1[1]]

让我们使用len(indices_dim0) == len(indices_dim1)的平面版本(在应用视图之前)。每个元素都需要与适当的批次索引匹配,该索引将为index。另外,[0, 0, 1, 1]还必须具有类型index,因为不能将浮点数用作索引。应该优先考虑使用torch.tensor与现有数据创建张量,因为torch.long是默认张量类型(torch.Tensor)的别名,而torch.FloatTensor自动使用表示以下内容的数据类型给定值,但还支持torch.tensor参数来手动设置类型,并且通常更具通用性。

dtype

索引不仅限于一维张量,而且它们都需要具有相同的大小,并且每个元素都用作一个索引,例如,对于2D张量,索引发生为# Type torch.long is inferred index = torch.tensor([0, 1, 0, 0]) # Same, but explicitly setting the type index = torch.tensor([0, 1, 0, 0], dtype=torch.long) batch_index = torch.tensor([0, 0, 1, 1]) L[batch_index, index] # => tensor([[[ 0, 1], # [ 2, 3]], # # [[ 4, 5], # [ 6, 7]], # # [[12, 13], # [14, 15]], # # [[12, 13], # [14, 15]]]) 使用2D张量,无需手动进行操作即可创建批次索引。

data[indices_dim0[i][j], indices_dim1[i][j]]