比方说,我们有一个B x C x W x H
大小的张量(这对于成批图像来说很常见),我们想将其重整为B x M
,其中M = C*W*H
。是否有一种内置的方式来执行此操作而不显式提及B
?
如果我们事先知道B
,即使没有明确知道三个C,W,H
中的任何一个,我们也可以这样做:
a = torch.randn(20,3,512,512)
b = a.reshape((20, -1)) #we can use -1 to infer the dimension `M`
但是我们还能在不知道B
的情况下这样做吗?
(我知道我们显然可以使用B
找到B = a.shape[0]
,但我的问题是,是否可能不知道B
也可以。)
答案 0 :(得分:0)
仅有的另一种方法是计算第二维,并使用-1作为第一维。
a = torch.randn(20,3,512,512)
print(a.shape)
b = a.reshape((20, -1))
print(b.shape)
b = a.reshape((-1, 786432)) # 3*512*512
print(b.shape)
torch.Size([20, 3, 512, 512])
torch.Size([20, 786432])
torch.Size([20, 786432])
因为整形时只能有一个-1
。
答案 1 :(得分:0)
原则上,您可以通过简单地使用输入的一维来使它成为一个通用功能,可以处理任何批量大小,例如:
a = torch.randn(20, 3, 512, 512)
b = a.reshape((a.shape[0], -1))
您可以将其包装在函数中,并在必要时调用它。