我实现了一个编码器(encoder)和一个解码器(decoder),但无法让解码器的最终输出与输入尺寸一致:输入为 [5, 3, 32, 32],而目前解码器的输出为 [5, 3, 64, 64]。应如何设计解码器,才能重建出与输入尺寸相同的图像?
import torch
from torch import nn
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one.

    Maps an (N, d1, d2, ...) tensor to an (N, d1 * d2 * ...) tensor.
    """

    def forward(self, input):
        # reshape rather than view: view raises on non-contiguous tensors
        # (e.g. after a transpose or with channels_last memory format),
        # while reshape returns a view when possible and copies otherwise.
        return input.reshape(input.size(0), -1)
class UnFlatten(nn.Module):
    """Reshape a batch of flat vectors into a batch of feature maps.

    Turns an (N, size * height * width) tensor into (N, size, height, width).
    The defaults give the same (N, 512, 1, 1) reshape as before, so
    ``UnFlatten()`` is unchanged. Pass e.g. ``UnFlatten(128, 2, 2)`` to undo
    a ``Flatten`` of a 128x2x2 feature map inside an ``nn.Sequential``, where
    ``forward``'s extra argument cannot be supplied.
    """

    def __init__(self, size=512, height=1, width=1):
        super().__init__()
        self.size = size  # number of output channels
        self.height = height
        self.width = width

    def forward(self, input, size=None):
        """Return *input* reshaped to (N, size, height, width).

        Args:
            input: tensor whose first dimension is the batch size N; each
                sample must hold exactly size * height * width elements.
            size: optional channel-count override; when omitted, the value
                given at construction (512 by default) is used.
        """
        if size is None:
            size = self.size
        # reshape (not view) so non-contiguous inputs are accepted as well.
        return input.reshape(input.size(0), size, self.height, self.width)
# Encoder. For (N, 3, 32, 32) inputs the three 4x4 stride-2 convolutions
# shrink the spatial size 32 -> 15 -> 6 -> 2. The resulting (N, 128, 2, 2)
# map is flattened to 128 * 2 * 2 = 512 features and projected to a
# 32-dimensional code.
# A fourth kernel_size=4, stride=2 conv does not fit at this resolution,
# because the 2x2 map is smaller than the kernel.
net = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=32, kernel_size=4, stride=2),    # -> 15x15
    nn.ReLU(),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),   # -> 6x6
    nn.ReLU(),
    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2),  # -> 2x2
    nn.ReLU(),
    Flatten(),
    nn.Linear(in_features=128 * 2 * 2, out_features=32),
)
# Decoder. Projects the 32-d code back to 512 features, and UnFlatten()
# views them as a (N, 512, 1, 1) map. Four transposed convolutions then grow
# the spatial size 1 -> 4 -> 8 -> 16 -> 32, so the output is
# (N, 3, 32, 32), the same shape as the encoder's input.
# ConvTranspose2d output size (dilation 1, no output_padding):
#     H_out = (H_in - 1) * stride - 2 * padding + kernel_size
# The previous kernel_size=5/6, stride=2 stack grew 1 -> 5 -> 13 -> 30 -> 64,
# which is why the reconstruction came out 64x64.
net2 = nn.Sequential(
    nn.Linear(32, 512),
    UnFlatten(),                                                      # (N, 512, 1, 1)
    nn.ConvTranspose2d(512, 128, kernel_size=4, stride=1),            # 1 -> 4
    nn.ReLU(),
    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 4 -> 8
    nn.ReLU(),
    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # 8 -> 16
    nn.ReLU(),
    nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),    # 16 -> 32
    nn.Sigmoid(),
)
# Smoke test: encode five all-zero 3x32x32 images into 32-d codes, decode
# those codes again, and print the tensor shape after each stage.
# NOTE(review): the module-level name `input` shadows the builtin input().
input = torch.zeros((5, 3, 32, 32))
mu = net(input)
print("mu shape", mu.size())
mu2 = net2(mu)
print("mu2 shape", mu2.size())