如何使用PyTorch基于蒙版修剪张量?

时间:2020-03-27 17:22:27

标签: python pytorch

我有一个张量inp,其张量为:torch.Size([4, 122, 161])

我还有一个mask,大小为:torch.Size([4, 122])

我的mask中的每个元素如下所示:

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0', grad_fn=<SelectBackward>)

因此,我想将inp修剪为沿维度= 1缩小,以仅存在于mask具有1的地方。在所示的情况下,有23个1,因此我希望inp的大小为:torch.Size([4, 23, 161])

1 个答案:

答案 0 :(得分:1)

我认为高级索引会起作用。 (我假设每个面罩平均有23 1s)

inp_trimmed = inp[mask.type(torch.bool)].reshape(4,23,161)