pytorch跳过顺序模型中的连接

时间:2018-08-09 17:55:05

标签: python deep-learning conv-neural-network pytorch sequential

我试图绕过连续模型中的跳过连接。使用功能性API,我将做的事情很简单(快速示例,也许在语法上不是100%正确,但应该可以理解):

x1 = self.conv1(inp)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)

x = self.deconv4(x)
x = self.deconv3(x)
x = self.deconv2(x)
x = torch.cat((x, x1), 1))
x = self.deconv1(x)

我现在正在使用顺序模型,并尝试做类似的事情,创建一个跳过连接,将第一个conv层的激活一直带到最后一个convTranspose。我看了实现here的U-net架构,这有点令人困惑,它的功能如下:

upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                    kernel_size=4, stride=2,
                                    padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]

if use_dropout:
    model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
    model = down + [submodule] + up

这不是只是按顺序向顺序模型中添加图层吗?有down转换,之后是submodule(递归地添加内层),然后串联到up,它是upconv层。我可能缺少关于Sequential API的工作方式的重要信息,但是从U-NET截断的代码实际上是如何实现跳过的呢?

1 个答案:

答案 0 :(得分:4)

您的观察是正确的,但是您可能错过了UnetSkipConnectionBlock.forward()的定义(UnetSkipConnectionBlock是定义您共享的U-Net块的Module),这可能会澄清此实现:

(来自pytorch-CycleGAN-and-pix2pix/models/networks.py#L259

# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):

    # ...

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)

最后一行是键(适用于所有内部块)。只需通过将输入x和(递归)块输出self.model(x)与您提到的操作列表self.model串联在一起即可完成跳过层,因此与{{ 1}}您编写的代码。