重塑时保留尺寸

时间:2019-07-09 15:03:24

标签: pytorch reshape

比方说,我们有一个B x C x W x H大小的张量(这对于成批图像来说很常见),我们想将其重整为B x M,其中M = C*W*H。是否有一种内置的方式来执行此操作而不显式提及B

如果我们事先知道B,即使没有明确知道三个C,W,H中的任何一个,我们也可以这样做:

a = torch.randn(20,3,512,512)
b = a.reshape((20, -1)) #we can use -1 to infer the dimension `M`

但是我们还能在不知道B的情况下这样做吗?

(我知道我们显然可以使用B找到B = a.shape[0],但我的问题是,是否可能不知道B也可以。)

2 个答案:

答案 0 :(得分:0)

仅有的另一种方法是计算第二维,并使用-1作为第一维。

a = torch.randn(20,3,512,512)
print(a.shape)
b = a.reshape((20, -1)) 
print(b.shape)
b = a.reshape((-1, 786432)) # 3*512*512
print(b.shape)

torch.Size([20, 3, 512, 512])
torch.Size([20, 786432])
torch.Size([20, 786432])

因为整形时只能有一个-1

答案 1 :(得分:0)

原则上,您可以通过简单地使用输入的一维来使它成为一个通用功能,可以处理任何批量大小,例如:

a = torch.randn(20, 3, 512, 512)
b = a.reshape((a.shape[0], -1))

您可以将其包装在函数中,并在必要时调用它。