通过顺序容器展平PyTorch构建层

时间:2017-08-09 08:01:57

标签: python conv-neural-network pytorch

我正在尝试通过PyTorch的顺序容器构建一个cnn,我的问题是我无法弄清楚如何压平图层。

main = nn.Sequential()
self._conv_block(main, 'conv_0', 3, 6, 5)
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
self._conv_block(main, 'conv_1', 6, 16, 3)
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2)) 
main.add_module('flatten', make_it_flatten)

我应该在“make_it_flatten”中加入什么? 我试图压扁主要但它不起作用,主要不存在调用视图

main = main.view(-1, 16*3*3)

2 个答案:

答案 0 :(得分:14)

这可能不是您正在寻找的,但您只需创建自己的nn.Module即可展平任何输入,然后您可以将其添加到nn.Sequential()对象:

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size()[0], -1)

x.size()[0]将选择批量暗淡,-1将计算所有剩余的暗淡以适应元素的数量,从而展平任何张量/变量。

nn.Sequential中使用它:

main = nn.Sequential()
self._conv_block(main, 'conv_0', 3, 6, 5)
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
self._conv_block(main, 'conv_1', 6, 16, 3)
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2)) 
main.add_module('flatten', Flatten())

答案 1 :(得分:1)

平整图层的最快方法是不创建新模块,而不会通过main.add_module('flatten', Flatten())将该模块添加到主模块中。

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

相反,就像我在here中所展示的那样,在模型out = inp.reshape(inp.size(0), -1)中进行一个简单的forward更快。