使用表示索引的一维长张量选择 3D Pytorch Tensor 的特定索引

时间:2021-04-01 08:07:56

标签: python pytorch tensor

所以我有一个 M x B x C 的张量,其中 M 是模型的数量,B 是批次,C 是类别,每个单元格是给定模型和批次的类别概率。然后我有一个正确答案的张量,它只是一个大小为 B 的一维,我们称之为“t”。如何使用大小为 B 的 1D 只返回 M x B x 1,其中返回的张量只是正确类的值?假设 M x B x C 张量被称为“blah”,我试过

blah[:, :, C]

for i in range(M):
    blah[i, :, C]

blah[:, C, :]

前 2 个只返回每个切片的第 3 个维度中索引 t 的值。最后一个返回第二维中 t 个索引处的值。我该怎么做?

3 个答案:

答案 0 :(得分:1)

我们可以通过combining advanced and basic indexing

得到想要的结果
import torch

# shape [2, 3, 4]
blah = torch.tensor([
    [[ 0,  1,  2,  3],
     [ 4,  5,  6,  7],
     [ 8,  9, 10, 11]],
    [[12, 13, 14, 15],
     [16, 17, 18, 19],
     [20, 21, 22, 23]]])

# shape [3]
t = torch.tensor([2, 1, 0])
b = torch.arange(blah.shape[1]).type_as(t)

# shape [2, 3, 1]
result = blah[:, b, t].unsqueeze(-1)

结果

>>> result
tensor([[[ 2],
         [ 5],
         [ 8]],
        [[14],
         [17],
         [20]]])

答案 1 :(得分:1)

您只需要通过:

  • 您的索引作为第三个切片
  • range(B) 作为第二个切片
    (即每个第 3 个暗淡索引对应于第 2 个暗淡中的哪个元素)
blah[:,range(B),t]

答案 2 :(得分:0)

这是一种方法:

假设 a 是您的 M x B x C 形张量。我在下面取一些有代表性的值,

>>> M = 3
>>> B = 5
>>> C = 4
>>> a = torch.rand(M, B, C)
>>> a
tensor([[[0.6222, 0.6703, 0.0057, 0.3210],
         [0.6251, 0.3286, 0.8451, 0.5978],
         [0.0808, 0.8408, 0.3795, 0.4872],
         [0.8589, 0.8891, 0.8033, 0.8906],
         [0.5620, 0.5275, 0.4272, 0.2286]],

        [[0.2419, 0.0179, 0.2052, 0.6859],
         [0.1868, 0.7766, 0.3648, 0.9697],
         [0.6750, 0.4715, 0.9377, 0.3220],
         [0.0537, 0.1719, 0.0013, 0.0537],
         [0.2681, 0.7514, 0.6523, 0.7703]],

        [[0.5285, 0.5360, 0.7949, 0.6210],
         [0.3066, 0.1138, 0.6412, 0.4724],
         [0.3599, 0.9624, 0.0266, 0.1455],
         [0.7474, 0.2999, 0.7476, 0.2889],
         [0.1779, 0.3515, 0.8900, 0.2301]]])

假设一维类张量是 t,它给出了批次中每个示例的真实类。所以它是一个形状为 (B, ) 的一维张量,其类别标签在 {0, 1, 2, ..., C-1} 范围内。

>>> t = torch.randint(C, size = (B, ))
>>> t
tensor([3, 2, 1, 1, 0])

所以基本上你想从t的最内层维度中选择对应于a的索引。这可以使用 fancy indexingbroadcasting 组合实现,如下所示:

>>> i = torch.arange(M).reshape(M, 1, 1)
>>> j = torch.arange(B).reshape(1, B, 1)
>>> k = t.reshape(1, B, 1)

请注意,一旦您通过 (i, j, k) 索引任何内容,它们将转到 expand 并采用形状 (M, B, 1),这是所需的输出形状。 现在只需通过 aij 索引 k 给出:

>>> a[i, j, k]
tensor([[[0.3210],
         [0.8451],
         [0.8408],
         [0.8891],
         [0.5620]],

        [[0.6859],
         [0.3648],
         [0.4715],
         [0.1719],
         [0.2681]],

        [[0.6210],
         [0.6412],
         [0.9624],
         [0.2999],
         [0.1779]]])

因此,本质上,如果您事先生成传达访问模式的索引数组,则可以直接使用它们来提取张量的一些切片。