火炬挤压和松开

时间:2020-05-04 18:06:38

标签: pytorch

如果这个问题已经提出,我深表歉意,但是我对pytorch的挤压和松解感到非常困惑。我试图查看文档和其他stackoverflow问题,但我仍然不确定它的实际作用。我看过What does "unsqueeze" do in Pytorch?,但还是不明白。

我试图通过自己在python中进行探索来理解它。我首先用

创建了一个随机张量
x = torch.rand(3,2,dtype=torch.float)
>>> x
tensor([[0.3703, 0.9588],
        [0.8064, 0.9716],
        [0.9585, 0.7860]])

但是无论我如何挤压,最终都会得到相同的结果:

>>> x.squeeze(0)
tensor([[0.3703, 0.9588],
        [0.8064, 0.9716],
        [0.9585, 0.7860]])
>>> x.squeeze(1)
tensor([[0.3703, 0.9588],
        [0.8064, 0.9716],
        [0.9585, 0.7860]])
>>> x.squeeze(-1)
tensor([[0.3703, 0.9588],
        [0.8064, 0.9716],
        [0.9585, 0.7860]])

如果我现在尝试松开,我得到以下信息,

>>> x.unsqueeze(1)
tensor([[[0.3703, 0.9588]],

        [[0.8064, 0.9716]],

        [[0.9585, 0.7860]]])
>>> x.unsqueeze(0)
tensor([[[0.3703, 0.9588],
         [0.8064, 0.9716],
         [0.9585, 0.7860]]])
>>> x.unsqueeze(-1)
tensor([[[0.3703],
         [0.9588]],

        [[0.8064],
         [0.9716]],

        [[0.9585],
         [0.7860]]])

但是,如果现在我创建张量x = torch.tensor([1,2,3,4]),并且尝试对其进行解压缩,那么看来1-1使其成为一列,其中0仍然是一样。

x.unsqueeze(0)
tensor([[1, 2, 3, 4]])
>>> x.unsqueeze(1)
tensor([[1],
        [2],
        [3],
        [4]])
>>> x.unsqueeze(-1)
tensor([[1],
        [2],
        [3],
        [4]])

有人可以解释张紧和张紧的过程吗?提供争论01-1有什么区别?

3 个答案:

答案 0 :(得分:7)

以下是 squeeze/unsqueeze 对有效二维矩阵的作用的直观表示:

enter image description here

当您解压张量时,您希望将其“解压”到哪个维度(如行或列等)是不明确的。 dim 参数说明了这一点 - 即要添加的新维度的位置。

因此生成的未压缩张量具有相同的信息,但用于访问它们的索引不同。

答案 1 :(得分:3)

简单地说,unsqueeze()将“ 1”的表面尺寸“添加”到张量(在指定尺寸),而squeeze将从张量中删除所有表面的1尺寸。

您应该查看张量的shape属性以轻松查看它。最后一种情况是:

import torch

tensor = torch.tensor([1, 0, 2, 3, 4])
tensor.shape # torch.Size([5])
tensor.unsqueeze(dim=0).shape # [1, 5]
tensor.unsqueeze(dim=1).shape # [5, 1]

对于将单个样本提供给网络(需要将第一维标注为批次)非常有用,对于图像而言,它将是:

# 3 channels, 32 width, 32 height
tensor = torch.randn(3, 32, 32)
# 1 batch, 3 channels, 32 width, 32 height
tensor.unsqueeze(dim=0).shape
如果您创建具有1个尺寸的unsqueeze,例如,可以看到

tensor。像这样:

# 3 channels, 32 width, 32 height and some 1 unnecessary dimensions
tensor = torch.randn(3, 1, 32, 1, 32, 1)
# 1 batch, 3 channels, 32 width, 32 height again
tensor.squeeze().unsqueeze(0) # [1, 3, 32, 32]

答案 2 :(得分:0)

  1. torch.unsqueeze(输入,暗淡)→张量
a = torch.randn(4, 4, 4)
torch.unsqueeze(a, 0).size()

>>> torch.Size([1, 4, 4, 4])
a = torch.randn(4, 4, 4)
torch.unsqueeze(a, 1).size()

>>> torch.Size([4, 1, 4, 4])
a = torch.randn(4, 4, 4)
torch.unsqueeze(a, 2).size()

>>> torch.Size([4, 4, 1, 4])
a = torch.randn(4, 4, 4)
torch.unsqueeze(a, 3).size()

>>> torch.Size([4, 4, 4, 1])
  1. torch.squeeze(input,dim = None,out = None)→张量
b = torch.randn(4, 1, 4)

>>> tensor([[[ 1.2912, -1.9050,  1.4771,  1.5517]],

        [[-0.3359, -0.2381, -0.3590,  0.0406]],

        [[-0.2460, -0.2326,  0.4511,  0.7255]],

        [[-0.1456, -0.0857, -0.8443,  1.1423]]])
b.size()

>>> torch.Size([4, 1, 4])

c = b.squeeze(1)

b
>>> tensor([[[ 1.2912, -1.9050,  1.4771,  1.5517]],

        [[-0.3359, -0.2381, -0.3590,  0.0406]],

        [[-0.2460, -0.2326,  0.4511,  0.7255]],

        [[-0.1456, -0.0857, -0.8443,  1.1423]]])

b.size()
>>> torch.Size([4, 1, 4])
c
>>> tensor([[ 1.2912, -1.9050,  1.4771,  1.5517],
        [-0.3359, -0.2381, -0.3590,  0.0406],
        [-0.2460, -0.2326,  0.4511,  0.7255],
        [-0.1456, -0.0857, -0.8443,  1.1423]])

c.size()
>>> torch.Size([4, 4])
相关问题