所以我有一个 M x B x C 的张量,其中 M 是模型的数量,B 是批次,C 是类别,每个单元格是给定模型和批次的类别概率。然后我有一个正确答案的张量,它只是一个大小为 B 的一维,我们称之为“t”。如何使用大小为 B 的 1D 只返回 M x B x 1,其中返回的张量只是正确类的值?假设 M x B x C 张量被称为“blah”,我试过
blah[:, :, C]
for i in range(M):
blah[i, :, C]
blah[:, C, :]
前 2 个只返回每个切片的第 3 个维度中索引 t 的值。最后一个返回第二维中 t 个索引处的值。我该怎么做?
答案 0 :(得分:1)
我们可以通过combining advanced and basic indexing
得到想要的结果import torch
# shape [2, 3, 4]
blah = torch.tensor([
[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
# shape [3]
t = torch.tensor([2, 1, 0])
b = torch.arange(blah.shape[1]).type_as(t)
# shape [2, 3, 1]
result = blah[:, b, t].unsqueeze(-1)
结果
>>> result
tensor([[[ 2],
[ 5],
[ 8]],
[[14],
[17],
[20]]])
答案 1 :(得分:1)
您只需要通过:
range(B)
作为第二个切片blah[:,range(B),t]
答案 2 :(得分:0)
这是一种方法:
假设 a
是您的 M x B x C
形张量。我在下面取一些有代表性的值,
>>> M = 3
>>> B = 5
>>> C = 4
>>> a = torch.rand(M, B, C)
>>> a
tensor([[[0.6222, 0.6703, 0.0057, 0.3210],
[0.6251, 0.3286, 0.8451, 0.5978],
[0.0808, 0.8408, 0.3795, 0.4872],
[0.8589, 0.8891, 0.8033, 0.8906],
[0.5620, 0.5275, 0.4272, 0.2286]],
[[0.2419, 0.0179, 0.2052, 0.6859],
[0.1868, 0.7766, 0.3648, 0.9697],
[0.6750, 0.4715, 0.9377, 0.3220],
[0.0537, 0.1719, 0.0013, 0.0537],
[0.2681, 0.7514, 0.6523, 0.7703]],
[[0.5285, 0.5360, 0.7949, 0.6210],
[0.3066, 0.1138, 0.6412, 0.4724],
[0.3599, 0.9624, 0.0266, 0.1455],
[0.7474, 0.2999, 0.7476, 0.2889],
[0.1779, 0.3515, 0.8900, 0.2301]]])
假设一维类张量是 t
,它给出了批次中每个示例的真实类。所以它是一个形状为 (B, )
的一维张量,其类别标签在 {0, 1, 2, ..., C-1}
范围内。
>>> t = torch.randint(C, size = (B, ))
>>> t
tensor([3, 2, 1, 1, 0])
所以基本上你想从t
的最内层维度中选择对应于a
的索引。这可以使用 fancy indexing 和 broadcasting 组合实现,如下所示:
>>> i = torch.arange(M).reshape(M, 1, 1)
>>> j = torch.arange(B).reshape(1, B, 1)
>>> k = t.reshape(1, B, 1)
请注意,一旦您通过 (i, j, k)
索引任何内容,它们将转到 expand 并采用形状 (M, B, 1)
,这是所需的输出形状。
现在只需通过 a
、i
和 j
索引 k
给出:
>>> a[i, j, k]
tensor([[[0.3210],
[0.8451],
[0.8408],
[0.8891],
[0.5620]],
[[0.6859],
[0.3648],
[0.4715],
[0.1719],
[0.2681]],
[[0.6210],
[0.6412],
[0.9624],
[0.2999],
[0.1779]]])
因此,本质上,如果您事先生成传达访问模式的索引数组,则可以直接使用它们来提取张量的一些切片。