如何使用PyTorch修剪/移除张量的一部分以匹配其他张量的形状?

时间:2020-04-09 13:55:22

标签: python pytorch tensor

我有2个张量:

outputs: torch.Size([4, 27, 161])       pred: torch.Size([4, 30, 161])

我想(从最后开始)切割pred,使其具有与outputs相同的尺寸。

使用PyTorch的最佳方法是什么?

2 个答案:

答案 0 :(得分:1)

您可以使用Narrow

例如:

a = torch.randn(4,30,161)
a.size() # torch.Size([4, 30, 161])
a.narrow(1,0,27).size() # torch.Size([4, 27, 161])

答案 1 :(得分:0)

如果你有两个张量的固定维数,你可以试试这个:

a = torch.randn(3, 5)
b = torch.zeros(3, 2)
b_h, b_w = b.shape
c = a[:b_h, :b_w]  # torch.Size([3, 2])

c 的形状与 b 相同,但值与 a 相同。