在Pytorch中使用跳过连接实现U-net

时间:2019-04-23 13:33:49

标签: transform pytorch

我想在pytorch中实现带有跳过连接的Unet。我的香草网络看起来像这样:

def createVanillaGan(self):

    # Encoding layers
    self.conv_1 = nn.Sequential(
    nn.Conv2d(1, 64, 3, stride=2, padding=1, bias=False), # in channel, out channel, filter kernel size
    nn.BatchNorm2d( 64 ),
    nn.LeakyReLU( 0.1 )
    )

    self.conv_2 = nn.Sequential(
    nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False), # in channel, out channel, filter kernel size
    nn.BatchNorm2d( 128 ),
    nn.LeakyReLU( 0.1 )
    )
    .
    .
    .
    self.conv_trans_4 = nn.Sequential(
    nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    )

    self.conv_trans_5 = nn.Sequential(
    nn.ConvTranspose2d(64, 3, 3, stride=2, padding=1, output_padding=1, bias=False),
    #nn.BatchNorm2d(3),
    nn.Tanh()
    )

我的前向通行证看起来像这样

def forward(self, data):
    output1 = self.conv_1(data) #10x64x128x128
    output2 = self.conv_2(output1) #10x128x64x64
    output3 = self.conv_3(output2) #10x256x32x32
    .
    .
    .
    #decoding
    output4_de = self.conv_trans_4(output3_de) #10x64x128x128

    output5_de = self.conv_trans_5(output4_de) #10x128x64x64

我想要做的是在正向传递中连接输出。我可以只在前向传递中执行torch.cat((output5,output5_de),1),还是需要对createVanillaGan(self)进行更改?我想知道这对反向传播有什么影响,或者我是否可以更改正向通过并完成?

0 个答案:

没有答案