我想在pytorch中实现带有跳过连接的Unet。我的香草网络看起来像这样:
def createVanillaGan(self):
# Encoding layers
self.conv_1 = nn.Sequential(
nn.Conv2d(1, 64, 3, stride=2, padding=1, bias=False), # in channel, out channel, filter kernel size
nn.BatchNorm2d( 64 ),
nn.LeakyReLU( 0.1 )
)
self.conv_2 = nn.Sequential(
nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False), # in channel, out channel, filter kernel size
nn.BatchNorm2d( 128 ),
nn.LeakyReLU( 0.1 )
)
.
.
.
self.conv_trans_4 = nn.Sequential(
nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(),
)
self.conv_trans_5 = nn.Sequential(
nn.ConvTranspose2d(64, 3, 3, stride=2, padding=1, output_padding=1, bias=False),
#nn.BatchNorm2d(3),
nn.Tanh()
)
我的前向通行证看起来像这样
def forward(self, data):
output1 = self.conv_1(data) #10x64x128x128
output2 = self.conv_2(output1) #10x128x64x64
output3 = self.conv_3(output2) #10x256x32x32
.
.
.
#decoding
output4_de = self.conv_trans_4(output3_de) #10x64x128x128
output5_de = self.conv_trans_5(output4_de) #10x128x64x64
我想要做的是在正向传递中连接输出。我可以只在前向传递中执行torch.cat((output5,output5_de),1),还是需要对createVanillaGan(self)进行更改?我想知道这对反向传播有什么影响,或者我是否可以更改正向通过并完成?