我有一个尺寸为M
的张量[NxQxD]
和一个索引为idx
(大小为N
)的1d张量。我想高效地创建尺寸为mask
的张量[NxQxD]
,使得mask[i,j,k] = 1 iff j <= idx[i]
,即我只想在idx[i]
中保留Q
的第一个尺寸。每行M
的{{1}}的第二维(dim = 1)。
谢谢!
答案 0 :(得分:0)
事实证明,这可以通过广播技巧来完成:
mask_2d = torch.arange(Q)[None, :] < idx[:, None] #(N,Q)
mask_3d = mask[..., None] #(N,Q,1)
masked = mask.float() * data