遮罩PyTorch中3D张量中的前k个元素(每行不同k)

时间:2019-09-04 12:28:52

标签: python pytorch tensor

我有一个尺寸为M的张量[NxQxD]和一个索引为idx(大小为N)的1d张量。我想高效地创建尺寸为mask的张量[NxQxD],使得mask[i,j,k] = 1 iff j <= idx[i],即我只想在idx[i]中保留Q的第一个尺寸。每行M的{​​{1}}的第二维(dim = 1)。

谢谢!

1 个答案:

答案 0 :(得分:0)

事实证明,这可以通过广播技巧来完成:

mask_2d = torch.arange(Q)[None, :] < idx[:, None] #(N,Q)
mask_3d = mask[..., None] #(N,Q,1)
masked = mask.float() * data