How does PyTorch's nn.Module register submodules?

Date: 2019-03-05 02:43:35

Tags: python pytorch

  

When I read the source code (Python) of torch.nn.Module, I found that the attribute self._modules is used in many methods, such as self.modules(), self.children(), and so on. However, I did not find any method that updates it. So where does self._modules get updated? Furthermore, how does PyTorch's nn.Module register submodules?

2 Answers:

Answer 0 (score: 1)

Modules and parameters are usually registered by setting an attribute of an nn.Module instance. In particular, this behavior is implemented by overriding the __setattr__ method:

def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, torch.Tensor):
                        raise TypeError("cannot assign '{}' as buffer '{}' "
                                        "(torch.Tensor or None expected)"
                                        .format(torch.typename(value), name))
                    buffers[name] = value
                else:
                    object.__setattr__(self, name, value)

See https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py to find this method.
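
As a quick illustration (the layer, attribute names, and sizes below are invented for this example, not taken from the answer), simply assigning a module as an attribute is enough to register it, after which it shows up in _modules, children(), and parameters():

    import torch.nn as nn

    m = nn.Module()
    m.fc = nn.Linear(4, 2)   # goes through Module.__setattr__ -> stored in m._modules['fc']
    m.scale = 3.14           # not a Module/Parameter/buffer, so it becomes a plain attribute

    print(list(m._modules.keys()))                    # ['fc']
    print([name for name, _ in m.named_children()])   # ['fc']
    print(sum(p.numel() for p in m.parameters()))     # 10 (weight 4*2 + bias 2)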

Answer 1 (score: 1)

To add some details to Jiren Jin's answer:

  • Network layers (which inherit from nn.Module) are stored in Module._modules, which is initialized in __construct:

    def __init__(self):
        self.__construct()
        # initialize self.training separately from the rest of the internal
        # state, as it is managed differently by nn.Module and ScriptModule
        self.training = True
    
    def __construct(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        # ...
        self._modules = OrderedDict()
    
  • self._modules is updated in __setattr__. __setattr__(obj, name, value) is called whenever obj.name = value is executed. For example, if you define self.conv1 = nn.Conv2d(128, 256, 3, 1, 1) while initializing a network that inherits from nn.Module, the following code in nn.Module.__setattr__ will be executed (see the sketch after this list for the error case this code guards against):

    def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]
    
        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            # ...
        elif params is not None and name in params:
            # ...
        else:
            modules = self.__dict__.get('_modules') # equivalent to modules = self._modules
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                # register the given layer (nn.Conv2d) with its name (conv1)
                # equivalent to self._modules['conv1'] = nn.Conv2d(128, 256, 3, 1, 1)
                modules[name] = value
    
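
To illustrate the guard in the branch above (the toy class below is made up for this example, not taken from the answer), assigning a layer before super().__init__() has run fails precisely because self._modules does not exist yet:

    import torch.nn as nn

    class Broken(nn.Module):
        def __init__(self):
            # super().__init__() is intentionally missing, so _modules was never created
            self.conv1 = nn.Conv2d(128, 256, 3, 1, 1)

    try:
        Broken()
    except AttributeError as e:
        print(e)  # cannot assign module before Module.__init__() call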

A question from the comments:

Are you aware of the fact that torch lets you provide your own forward method?

When you run the forward pass of a network that inherits from nn.Module, nn.Module.__call__ is invoked, and it in turn calls self.forward. Since forward is overridden when the network is implemented, the user-defined forward is what actually runs.
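
A minimal sketch tying the two points together (the layer and input shape are chosen arbitrarily here): attribute assignment fills _modules, and calling the model goes through nn.Module.__call__, which dispatches to the user-defined forward:

    import torch
    import torch.nn as nn

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()                      # creates self._modules, self._parameters, ...
            self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)   # registered as self._modules['conv1']

        def forward(self, x):                       # overrides nn.Module.forward
            return self.conv1(x)

    net = TinyNet()
    print(list(net._modules.keys()))        # ['conv1']
    out = net(torch.randn(1, 3, 16, 16))    # net(...) -> __call__ -> self.forward(...)
    print(out.shape)                        # torch.Size([1, 8, 16, 16])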