火炬挤压和批量尺寸

时间:2020-03-10 14:39:32

标签: pytorch torch

这里有人知道torch.squeeze函数是否遵守批次(例如第一个)维吗?从某些内联代码来看似乎没有。.但是也许其他人比我更了解内部工作。

顺便说一句,潜在的问题是我的形状为(n_batch, channel, x, y, 1)。我想用一个简单的函数删除最后一个尺寸,以便最终得到(n_batch, channel, x, y)的形状。

当然可以进行整形,甚至可以选择最后一个轴。但是我想将此功能嵌入层中,以便可以轻松地将其添加到ModuleListSequence对象中。

2 个答案:

答案 0 :(得分:3)

不!挤压不遵守批次尺寸。如果在批处理尺寸可能为1时使用挤压,则可能会导致错误。根据经验,默认情况下,仅torch.nn中的类和函数会尊重批处理尺寸。

这过去使我头疼。 我建议使用reshape或仅将squeeze与可选的输入尺寸参数一起使用。根据您的情况,您可以使用.squeeze(4)仅删除最后一个尺寸。这样就不会发生意外情况。没有输入尺寸的挤压导致我产生了意外的结果,尤其是在

  1. 模型的输入形状可能会变化
  2. 批量大小可能有所不同
  3. nn.DataParallel正在使用(在这种情况下,特定实例的批处理大小可能会减小到1)

答案 1 :(得分:0)

接受的答案足以解决问题 - squeeze 最后一个维度。但是,我有维度 (batch, 1280, 1, 1) 的张量并且想要 (batch, 1280)Squeeze 函数不允许 - squeeze(tensor, 1).shape -> (batch, 1280, 1, 1)squeeze(tensor, 2).shape -> (batch, 1280, 1)。我本可以使用 squeeze 两次,但你知道,美学:)

帮助我的是torch.flatten(tensor, start_dim = 1) -> (batch, 1280)。微不足道,但我忘记了。但是警告,这个函数我创建了一个副本而不是视图,所以要小心。

https://pytorch.org/docs/stable/generated/torch.flatten.html