这里有人知道torch.squeeze
函数是否遵守批次(例如第一个)维吗?从某些内联代码来看似乎没有。.但是也许其他人比我更了解内部工作。
顺便说一句,潜在的问题是我的形状为(n_batch, channel, x, y, 1)
。我想用一个简单的函数删除最后一个尺寸,以便最终得到(n_batch, channel, x, y)
的形状。
当然可以进行整形,甚至可以选择最后一个轴。但是我想将此功能嵌入层中,以便可以轻松地将其添加到ModuleList
或Sequence
对象中。
答案 0 :(得分:3)
不!挤压不遵守批次尺寸。如果在批处理尺寸可能为1时使用挤压,则可能会导致错误。根据经验,默认情况下,仅torch.nn中的类和函数会尊重批处理尺寸。
这过去使我头疼。 我建议使用reshape
或仅将squeeze
与可选的输入尺寸参数一起使用。根据您的情况,您可以使用.squeeze(4)
仅删除最后一个尺寸。这样就不会发生意外情况。没有输入尺寸的挤压导致我产生了意外的结果,尤其是在
nn.DataParallel
正在使用(在这种情况下,特定实例的批处理大小可能会减小到1)答案 1 :(得分:0)
接受的答案足以解决问题 - squeeze
最后一个维度。但是,我有维度 (batch, 1280, 1, 1)
的张量并且想要 (batch, 1280)
。 Squeeze
函数不允许 - squeeze(tensor, 1).shape
-> (batch, 1280, 1, 1)
和 squeeze(tensor, 2).shape
-> (batch, 1280, 1)
。我本可以使用 squeeze
两次,但你知道,美学:)
帮助我的是torch.flatten(tensor, start_dim = 1)
-> (batch, 1280)
。微不足道,但我忘记了。但是警告,这个函数我创建了一个副本而不是视图,所以要小心。
https://pytorch.org/docs/stable/generated/torch.flatten.html