PyTorch展平不保持批量大小

时间:2020-02-07 14:42:58

标签: python pytorch

在Keras中,使用Flatten()层可保留批次大小。例如,如果Flatten的输入形状是(32, 100, 100),则在Keras中Flatten的输出形状是(32, 10000),而在PyTorch中它是320000。为什么会这样?

2 个答案:

答案 0 :(得分:5)

正如OP在其答案中已经指出的那样,张量操作没有默认考虑批处理维。您可以将torch.flatten()Tensor.flatten()start_dim=1配合使用,以在批处理尺寸之后开始展平操作。

或者,从PyTorch 1.2.0开始,您可以在模型中定义一个nn.Flatten()层,默认为start_dim=1

答案 1 :(得分:2)

是的,如this thread中所述,PyTorch操作(如Flatten,view,reshape)。

通常,当使用诸如Conv2d之类的模块时,您不必担心批处理大小。 PyTorch会照顾好它。但是,当直接处理张量时,您需要注意批量大小。

在Keras中,Flatten()是一层。但是在PyTorch中,flatten()是对张量的操作。因此,批量大小需要手动进行。