如何索引具有n个维度的张量t
和m个t
的最后一个维度?对于尺寸m之前的所有尺寸,index
张量的形状等于张量t
。换句话说,我想索引张量的中间维度,同时保留选定索引的以下所有维度。
例如,假设我们有两个张量:
t = torch.randn([3, 5, 2]) * 10
index = torch.tensor([[1, 3],[0,4],[3,2]]).long()
带有t:
tensor([[[ 15.2165, -7.9702],
[ 0.6646, 5.2844],
[-22.0657, -5.9876],
[ -9.7319, 11.7384],
[ 4.3985, -6.7058]],
[[-15.6854, -11.9362],
[ 11.3054, 3.3068],
[ -4.7756, -7.4524],
[ 5.0977, -17.3831],
[ 3.9152, -11.5047]],
[[ -5.4265, -22.6456],
[ 1.6639, 10.1483],
[ 13.2129, 3.7850],
[ 3.8543, -4.3496],
[ -8.7577, -12.9722]]])
然后我想要的输出将具有(3, 2, 2)
的形状,并且是:
tensor([[[ 0.6646, 5.2844],
[ -9.7319, 11.7384]],
[[-15.6854, -11.9362],
[ 3.9152, -11.5047]],
[[ 3.8543, -4.3496],
[ 13.2129, 3.7850]]])
另一个示例是我有一个形状为t
的张量(40, 10, 6, 2)
和一个形状为(40, 10, 3)
的索引张量。这应该查询张量t
的维度3,并且预期的输出形状将为(40, 10, 3, 2)
。
如何在不使用循环的情况下以通用方式实现这一目标?
答案 0 :(得分:1)
在这种情况下,您可以执行以下操作:
t[torch.arange(t.shape[0]).unsqueeze(1), index, ...]
完整代码:
import torch
t = torch.tensor([[[ 15.2165, -7.9702],
[ 0.6646, 5.2844],
[-22.0657, -5.9876],
[ -9.7319, 11.7384],
[ 4.3985, -6.7058]],
[[-15.6854, -11.9362],
[ 11.3054, 3.3068],
[ -4.7756, -7.4524],
[ 5.0977, -17.3831],
[ 3.9152, -11.5047]],
[[ -5.4265, -22.6456],
[ 1.6639, 10.1483],
[ 13.2129, 3.7850],
[ 3.8543, -4.3496],
[ -8.7577, -12.9722]]])
index = torch.tensor([[1, 3],[0,4],[3,2]]).long()
output = t[torch.arange(t.shape[0]).unsqueeze(1), index, ...]
# tensor([[[ 0.6646, 5.2844],
# [ -9.7319, 11.7384]],
#
# [[-15.6854, -11.9362],
# [ 3.9152, -11.5047]],
#
# [[ 3.8543, -4.3496],
# [ 13.2129, 3.7850]]])