在Keras中,使用Flatten()
层可保留批次大小。例如,如果Flatten的输入形状是(32, 100, 100)
,则在Keras
中Flatten的输出形状是(32, 10000)
,而在PyTorch中它是320000
。为什么会这样?
答案 0 :(得分:5)
正如OP在其答案中已经指出的那样,张量操作没有默认考虑批处理维。您可以将torch.flatten()
或Tensor.flatten()
与start_dim=1
配合使用,以在批处理尺寸之后开始展平操作。
或者,从PyTorch 1.2.0开始,您可以在模型中定义一个nn.Flatten()
层,默认为start_dim=1
。
答案 1 :(得分:2)
是的,如this thread中所述,PyTorch操作(如Flatten,view,reshape)。
通常,当使用诸如Conv2d
之类的模块时,您不必担心批处理大小。 PyTorch会照顾好它。但是,当直接处理张量时,您需要注意批量大小。
在Keras中,Flatten()
是一层。但是在PyTorch中,flatten()
是对张量的操作。因此,批量大小需要手动进行。