Module.parameters()如何找到参数?

时间:2019-02-18 16:14:05

标签: python-3.x pytorch

我注意到,只要您创建一个扩展了torch.nn.Module的新网络,就可以立即调用net.parameters()来查找与反向传播相关的参数。

import torch

class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.fc(x)

net = MyNet()
print(list(net.parameters()))

但是后来我想知道,这怎么可能?我刚刚将此Linear层对象分配给成员变量,但未在其他任何地方记录它(或者是?)。 MyNet必须能够以某种方式跟踪所使用的参数,但是如何?

1 个答案:

答案 0 :(得分:2)

这很简单,只需通过元编程检查属性并检查其类型

class Example():
    def __init__(self):
        self.special_thing = nn.Parameter(torch.rand(2))
        self.something_else = "ok"

    def get_parameters(self):
        for key, value in self.__dict__.items():
            if type(value) == nn.Parameter:
                print(key, "is a parameter!")


e = Example()
e.get_parameters()
# => special_thing is a parameter!