How to flatten the input inside `nn.Sequential` in PyTorch

Asked: 2018-12-28 03:49:04

Tags: python neural-network artificial-intelligence pytorch

How can I flatten the input inside nn.Sequential?

Model = nn.Sequential(x.view(x.shape[0],-1),
                     nn.Linear(784,256),
                     nn.ReLU(),
                     nn.Linear(256,128),
                     nn.ReLU(),
                     nn.Linear(128,64),
                     nn.ReLU(),
                     nn.Linear(64,10),
                     nn.LogSoftmax(dim=1))

3 Answers:

Answer 0 (score: 4)

You can create a new module/class as below and use it in the sequential container just like any other module (by calling Flatten()).
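A minimal sketch of such a module, with the questioner's layers dropped in (this follows the pattern from the reference below; the view-based body mirrors the Flatten class defined in the last answer):

import torch.nn as nn

class Flatten(nn.Module):
    # Keep the batch dimension; collapse all remaining dimensions into one.
    def forward(self, x):
        return x.view(x.size(0), -1)

Model = nn.Sequential(Flatten(),
                      nn.Linear(784, 256),
                      nn.ReLU(),
                      nn.Linear(256, 128),
                      nn.ReLU(),
                      nn.Linear(128, 64),
                      nn.ReLU(),
                      nn.Linear(64, 10),
                      nn.LogSoftmax(dim=1))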

Reference: https://discuss.pytorch.org/t/flatten-layer-of-pytorch-build-by-sequential-container/5983

Answer 1 (score: 2)

You can modify your code as follows:

Model = nn.Sequential(nn.Flatten(),  # default start_dim=1 keeps the batch dimension
                     nn.Linear(784,256),
                     nn.ReLU(),
                     nn.Linear(256,128),
                     nn.ReLU(),
                     nn.Linear(128,64),
                     nn.ReLU(),
                     nn.Linear(64,10),
                     nn.LogSoftmax(dim=1))
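For example, with a batch of MNIST-sized images (the batch size of 64 and the 1×28×28 input shape are assumptions, inferred from the 784-unit first layer):

import torch

x = torch.randn(64, 1, 28, 28)  # a batch of 64 single-channel 28x28 images
out = Model(x)                  # nn.Flatten() maps (64, 1, 28, 28) to (64, 784)
print(out.shape)                # torch.Size([64, 10])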

Answer 2 (score: 0)

The flatten method, defined as

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

is comparable in speed to view(), but reshape is even faster.

import torch
import torch.nn as nn

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

flatten = Flatten()

t = torch.Tensor(3,2,2).random_(0, 10)
print(t, t.shape)


#https://pytorch.org/docs/master/torch.html#torch.flatten
f = torch.flatten(t, start_dim=1, end_dim=-1)
print(f, f.shape)


#https://pytorch.org/docs/master/torch.html#torch.view
f = t.view(t.size(0), -1)
print(f, f.shape)


#https://pytorch.org/docs/master/torch.html#torch.reshape
f = t.reshape(t.size(0), -1)
print(f, f.shape)

Speed check:
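The post shows only the results; presumably they came from IPython %timeit runs along these lines (an assumption), using the tensor t from above:

%timeit f = torch.flatten(t, start_dim=1, end_dim=-1)
%timeit f = t.view(t.size(0), -1)
%timeit f = t.reshape(t.size(0), -1)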

# flatten 3.49 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# view 3.23 µs ± 228 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# reshape 3.04 µs ± 93 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

If we were to use the class defined above:

flatten = Flatten()
t = torch.Tensor(3,2,2).random_(0, 10)
%timeit f=flatten(t)


5.16 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

This result shows that creating a class is the slower approach. This is why it is faster to flatten tensors inside forward(). I think this is the main reason they haven't promoted nn.Flatten.

So my suggestion is to flatten inside forward() for speed. Something like this:

out = inp.reshape(inp.size(0), -1)
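Put in context, that suggestion amounts to something like the following sketch (the layer sizes are taken from the question; the class name Net is illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.reshape(x.size(0), -1)  # flatten inside forward, as suggested
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return F.log_softmax(self.fc4(x), dim=1)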