如何获得Unet的内部模块?

时间:2017-09-20 06:56:24

标签: pytorch

我使用UnetGenerator创建了UNet。您可以找到结果结构here

  1. 如何获取模块Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  2. 如何获取模块(5): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  3. 我想让内部模块更改某些图层的属性。 我试过像net.modules(i).modules(i)这样的东西,但它没有用。我参考了文档,我没有一个好主意去做。

    我的初衷是在训练时更改某些图层的属性。我可以添加自定义图层myLayer,其中self.mode='normal'。在训练时,我希望我可以更改其属性myLayer.mode = 'capture',以使其在训练中改变其行为。

2 个答案:

答案 0 :(得分:2)

nn.Module的所有子类都有一个名为children的属性。您可以使用以下代码访问该文件。

unet = UnetGenerator(512,512,4)
layers = list(unet.children())
len(layers)

对于我使用上述代码创建的网络,我可以访问网络中的一个层并更改如下所示的属性。

l = layers[0]
conv = list(l.children())[0][0]
conv.kernel_size = (2,2)

如果您在不使用任何预先训练的权重的情况下训练网络,那么您可以在创建网络对象之前对源代码进行更改。

答案 1 :(得分:1)

添加Vishnu的正确答案:

如果要进行固定更改,最好在创建网络对象之前更改代码。

但是,关于pytorch的一个重要事项是,您可以动态更改属性:大多数深度学习库(如tensorflow)出于性能原因使用静态图。这意味着网络图形构建一次并在每次正向传递时执行。

另一方面,

pytorch使用动态计算图,这意味着对于每个正向传递,图形是动态构建的。这使您有机会根据一些不同的参数更改网络体系结构。 例如,如果您的损失低于某个值或您的纪元是奇数(我并不暗示这些是合理的应用程序),您可以减少内核大小。所有这些动态更改都应该在forward函数的实现中发生,您可以向其传递其他参数,例如:

class MyNet(nn.Module):
    def __init__(input_nc, output_nc):
        super().__init__()
        # define layers 
        # ...
        self.choice_A = nn.Conv2d(input_nc, output_nc, kernel_size = 3)
        self.choice_B = nn.Conv2d(input_nc, output_nc, kernel_size = 4)
        # continue init

    def forward(self, x, epoch):
        # start forward pass
        # ...
        if epoch % 2 == 0:
            x = self.choice_A(x)
        else:
            x = self.choice_B(x)
        # ...
        return x