-1在pytorch视图中是什么意思?

时间:2018-06-11 07:21:32

标签: reshape pytorch

正如问题所说,-1在pytorch view中做了什么?

In [2]: a = torch.arange(1, 17)

In [3]: a
Out[3]:
tensor([  1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
         11.,  12.,  13.,  14.,  15.,  16.])

In [7]: a.view(-1,1)
Out[7]:
tensor([[  1.],
        [  2.],
        [  3.],
        [  4.],
        [  5.],
        [  6.],
        [  7.],
        [  8.],
        [  9.],
        [ 10.],
        [ 11.],
        [ 12.],
        [ 13.],
        [ 14.],
        [ 15.],
        [ 16.]])

In [8]: a.view(1,-1)
Out[8]:
tensor([[  1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
          11.,  12.,  13.,  14.,  15.,  16.]])

是否(-1)生成了额外的维度? 它是否与numpy reshape -1的行为相同?

5 个答案:

答案 0 :(得分:6)

是的,它的行为与-1中的numpy.reshape()相似,即推断出此维度的实际值,以便视图中的元素数量与原始元素数量相匹配。

例如:

import torch

x = torch.arange(6)

print(x.view(3, -1))      # inferred size will be 2 as 6 / 3 = 2
# tensor([[ 0.,  1.],
#         [ 2.,  3.],
#         [ 4.,  5.]])

print(x.view(-1, 6))      # inferred size will be 1 as 6 / 6 = 1
# tensor([[ 0.,  1.,  2.,  3.,  4.,  5.]])

print(x.view(1, -1, 2))   # inferred size will be 3 as 6 / (1 * 2) = 3
# tensor([[[ 0.,  1.],
#          [ 2.,  3.],
#          [ 4.,  5.]]])

# print(x.view(-1, 5))    # throw error as there's no int N so that 5 * N = 6
# RuntimeError: invalid argument 2: size '[-1 x 5]' is invalid for input with 6 elements

print(x.view(-1, -1, 3))  # throw error as only one dimension can be inferred
# RuntimeError: invalid argument 1: only one dimension can be inferred

答案 1 :(得分:3)

我喜欢本杰明给出的答案https://stackoverflow.com/a/50793899/1601580

<块引用>

是的,它在 numpy.reshape() 中的行为确实类似于 -1,即将推断此维度的实际值,以便视图中的元素数量与原始元素数量相匹配。

但我认为对您而言可能不直观(或至少对我而言不是)的奇怪案例边缘案例是在使用单个 -1 即 tensor.view(-1) 调用它时。 我的猜测是它的工作方式与往常完全相同,只是因为您提供一个数字来查看它假定您需要一个维度。如果您有 tensor.view(-1, Dnew),它将产生一个二维/索引的张量,但会根据张量的原始维度确保第一个维度的大小正确。假设你有 (D1, D2) 你有 Dnew=D1*D2 那么新的维度就是 1。

对于带有代码的真实示例,您可以运行:

import torch

x = torch.randn(1, 5)
x = x.view(-1)
print(x.size())

x = torch.randn(2, 4)
x = x.view(-1, 8)
print(x.size())

x = torch.randn(2, 4)
x = x.view(-1)
print(x.size())

x = torch.randn(2, 4, 3)
x = x.view(-1)
print(x.size())

输出:

torch.Size([5])
torch.Size([1, 8])
torch.Size([8])
torch.Size([24])

历史/背景

我觉得一个很好的例子(common case early on in pytorch before 扁平化层是 official added 是这个常见的代码):

class Flatten(nn.Module):
    def forward(self, input):
        # input.size(0) usually denotes the batch size so we want to keep that
        return input.view(input.size(0), -1)

用于顺序。在这个视图中,x.view(-1) 是一个奇怪的扁平层,但缺少挤压(即添加维度为 1)。添加或删除此压缩对于代码实际运行通常很重要。

答案 2 :(得分:1)

我猜这与np.reshape类似。

来自here

新形状应与原始形状兼容。如果是整数,则结果将是该长度的1-D数组。一个形状尺寸可以是-1。在这种情况下,该值是从数组的长度和剩余维度推断出来的。

如果您选择a = torch.arange(1, 18),可以通过a.view(-1,6)a.view(-1,9)a.view(3,-1)等各种方式查看。

答案 3 :(得分:1)

From the PyTorch documentation

>>> x = torch.randn(4, 4)
>>> x.size()
torch.Size([4, 4])
>>> y = x.view(16)
>>> y.size()
torch.Size([16])
>>> z = x.view(-1, 8)  # the size -1 is inferred from other dimensions
>>> z.size()
torch.Size([2, 8])

答案 4 :(得分:0)

例如,

-1推断为2,如果有

>>> a = torch.rand(4,4)
>>> a.size()
torch.size([4,4])
>>> y = x.view(16)
>>> y.size()
torch.size([16])
>>> z = x.view(-1,8) # -1 is generally inferred as 2  i.e (2,8)
>>> z.size()
torch.size([2,8])