当我阅读torch.nn.Module的源代码(python)时,我发现了 属性
self._modules
已在许多功能中使用,例如self.modules(), self.children()
等。但是,我没有找到任何功能 更新它。因此,self._modules
将在哪里更新? 此外,pytorch的{{1}}如何注册子模块?
nn.Module
答案 0 :(得分:1)
通常通过设置nn.module
实例的属性来注册模块和参数。
特别是,这种行为是通过对__serattr__
方法进行裁剪来实现的:
def __setattr__(self, name, value):
def remove_from(*dicts):
for d in dicts:
if name in d:
del d[name]
params = self.__dict__.get('_parameters')
if isinstance(value, Parameter):
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call")
remove_from(self.__dict__, self._buffers, self._modules)
self.register_parameter(name, value)
elif params is not None and name in params:
if value is not None:
raise TypeError("cannot assign '{}' as parameter '{}' "
"(torch.nn.Parameter or None expected)"
.format(torch.typename(value), name))
self.register_parameter(name, value)
else:
modules = self.__dict__.get('_modules')
if isinstance(value, Module):
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call")
remove_from(self.__dict__, self._parameters, self._buffers)
modules[name] = value
elif modules is not None and name in modules:
if value is not None:
raise TypeError("cannot assign '{}' as child module '{}' "
"(torch.nn.Module or None expected)"
.format(torch.typename(value), name))
modules[name] = value
else:
buffers = self.__dict__.get('_buffers')
if buffers is not None and name in buffers:
if value is not None and not isinstance(value, torch.Tensor):
raise TypeError("cannot assign '{}' as buffer '{}' "
"(torch.Tensor or None expected)"
.format(torch.typename(value), name))
buffers[name] = value
else:
object.__setattr__(self, name, value)
请参见https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py来找到此方法。
答案 1 :(得分:1)
在Jiren Jin的答案中添加一些细节:
网络层(从nn.Module
继承)存储在Module._modules
中,并在__construct
中初始化:
def __init__(self):
self.__construct()
# initialize self.training separately from the rest of the internal
# state, as it is managed differently by nn.Module and ScriptModule
self.training = True
def __construct(self):
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
"""
# ...
self._modules = OrderedDict()
self._modules
在__setattr__
中进行了更新。执行__setattr__(obj, name, value)
时将调用obj.name = value
。例如,如果在初始化从self.conv1 = nn.Conv2d(128, 256, 3, 1, 1)
继承的网络时定义nn.Module
,则将执行nn.Module.__setattr__
的以下代码:
def __setattr__(self, name, value):
def remove_from(*dicts):
for d in dicts:
if name in d:
del d[name]
params = self.__dict__.get('_parameters')
if isinstance(value, Parameter):
# ...
elif params is not None and name in params:
# ...
else:
modules = self.__dict__.get('_modules') # equivalent to modules = self._modules
if isinstance(value, Module):
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call")
remove_from(self.__dict__, self._parameters, self._buffers)
# register the given layer (nn.Conv2d) with its name (conv1)
# equivalent to self._modules['conv1'] = nn.Conv2d(128, 256, 3, 1, 1)
modules[name] = value
评论问题:
您知道火炬可以让您提供自己的前进方法这一事实吗?
如果运行从nn.Module
继承的网络的前向传递,则将调用nn.Module.__call__
,其中将调用self.forward
。但是,在实现网络时,有人已经覆盖了forward
。