我注意到,只要您创建一个扩展了torch.nn.Module
的新网络,就可以立即调用net.parameters()
来查找与反向传播相关的参数。
import torch
class MyNet(torch.nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.fc = torch.nn.Linear(5, 5)
def forward(self, x):
return self.fc(x)
net = MyNet()
print(list(net.parameters()))
但是后来我想知道,这怎么可能?我刚刚将此Linear
层对象分配给成员变量,但未在其他任何地方记录它(或者是?)。 MyNet
必须能够以某种方式跟踪所使用的参数,但是如何?
答案 0 :(得分:2)
这很简单,只需通过元编程检查属性并检查其类型
class Example():
def __init__(self):
self.special_thing = nn.Parameter(torch.rand(2))
self.something_else = "ok"
def get_parameters(self):
for key, value in self.__dict__.items():
if type(value) == nn.Parameter:
print(key, "is a parameter!")
e = Example()
e.get_parameters()
# => special_thing is a parameter!