如何在PyTorch中缩小张量的最后一个维度?

时间:2019-04-02 09:26:24

标签: arrays python-3.x numpy pytorch tensor

我的张量为(1, 3, 256, 256, 3)。我需要缩小尺寸之一以获得形状(1, 3, 256, 256)。我该怎么做?

谢谢!

1 个答案:

答案 0 :(得分:0)

如果您打算在最后一个维度上应用均值,则可以执行以下操作:

In [18]: t = torch.randn((1, 3, 256, 256, 3))

In [19]: t.shape
Out[19]: torch.Size([1, 3, 256, 256, 3])

# apply mean over the last dimension
In [23]: t_reduced = torch.mean(t, -1)

In [24]: t_reduced.shape
Out[24]: torch.Size([1, 3, 256, 256])

# equivalently
In [32]: torch.mean(t, t.ndimension()-1).shape
Out[32]: torch.Size([1, 3, 256, 256])