我对以下代码片段有疑问:
>>> def init_weights(m):
print(m)
if type(m) == nn.Linear:
m.weight.data.fill_(1.0)
print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
apply()是pytorch.nn包的一部分。您可以在此软件包的文档中找到代码。最后的问题: 1.尽管将init_weights()赋予apply()时未添加任何参数或方括号,但此代码示例为何起作用? 2.当函数init_weights(m)作为参数apply()不带括号和m时,函数init_weights(m)从哪里获得参数?
答案 0 :(得分:0)
我们在上述torch.nn.Module.apply(fn)
文档中找到了您的问题的答案:
将
fn
递归应用于每个子模块(由.children()返回) 以及自我。典型用途包括初始化模型的参数 (另请参见torch-nn-init)。
init_weights
不会在apply
调用之前被调用,正是因为没有括号,而是init_weights
被引用了apply
,并且仅在apply
之后init_weights
内部被调用。apply
中的每次调用都会获取其参数,并且正如文档所指出的那样,需要对net
和{{1 }}本身,这归功于方法调用net
。