我有一个尺寸为M
的3D张量[BxLxD]
和一个尺寸为idx
的1D张量[B,1]
,其中包含范围为(0, L-1)
的列索引。我想创建一个尺寸为N
的2D张量[BxD]
,使得N[i,j] = M[i, idx[i], j]
。如何有效地做到这一点?
示例:
B,L,D = 2,4,2
M = torch.rand(B,L,D)
>
tensor([[[0.0612, 0.7385],
[0.7675, 0.3444],
[0.9129, 0.7601],
[0.0567, 0.5602]],
[[0.5450, 0.3749],
[0.4212, 0.9243],
[0.1965, 0.9654],
[0.7230, 0.6295]]])
idx = torch.randint(0, L, size = (B,))
>
tensor([3, 0])
N = get_N(M, idx)
Expected output:
>
tensor([[0.0567, 0.5602],
[0.5450, 0.3749]])
谢谢。
答案 0 :(得分:1)
import torch
B,L,D = 2,4,2
def get_N(M, idx):
return M[torch.arange(B), idx, :].squeeze()
M = torch.tensor([[[0.0612, 0.7385],
[0.7675, 0.3444],
[0.9129, 0.7601],
[0.0567, 0.5602]],
[[0.5450, 0.3749],
[0.4212, 0.9243],
[0.1965, 0.9654],
[0.7230, 0.6295]]])
idx = torch.tensor([3,0])
N = get_N(M, idx)
print(N)
结果:
tensor([[0.0567, 0.5602],
[0.5450, 0.3749]])
沿二维切片。