自动编码器匹配输出到输入

时间:2018-12-01 23:03:21

标签: python pytorch autoencoder

我有编码解码功能。我无法解码图像以使最终输出与输入大小一致,即[5,3,32,32]。 如何在解码器中重建图像,以使输入和输出的图像尺寸一致? 请快点!!!

from torch import nn
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class UnFlatten(nn.Module):
    def forward(self, input, size=512):
        return input.view(input.size(0), size, 1, 1)


net = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.ReLU(),

            #nn.Conv2d(128, 256, kernel_size=4, stride=2),
            #nn.ReLU(),
            Flatten(),
            nn.Linear(512, 32),

        )
net2= nn.Sequential(
            nn.Linear(32, 512),
            UnFlatten(),
            nn.ConvTranspose2d(512, 128, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=6, stride=2),
            nn.Sigmoid(),
        )

input = torch.zeros(5,3,32,32)

mu=net(input)

print("mu shape", mu.shape)

mu2= net2(mu)

print("mu2 shape", mu2.shape)

0 个答案:

没有答案