我有一个张量inp
,其张量为:torch.Size([4, 122, 161])
。
我还有一个mask
,大小为:torch.Size([4, 122])
。
我的mask
中的每个元素如下所示:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
device='cuda:0', grad_fn=<SelectBackward>)
因此,我想将inp
修剪为沿维度= 1缩小,以仅存在于mask
具有1
的地方。在所示的情况下,有23个1
,因此我希望inp
的大小为:torch.Size([4, 23, 161])
答案 0 :(得分:1)
我认为高级索引会起作用。 (我假设每个面罩平均有23 1s)
inp_trimmed = inp[mask.type(torch.bool)].reshape(4,23,161)