Pytorch:了解nn.Module类在内部如何工作

时间:2019-11-11 04:34:48

标签: python deep-learning pytorch

通常,nn.Module可由子类继承,如下所示。

def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform(m.weight)  # 

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.fc1 = nn.Linear(20, 1)
        self.apply(init_weights)

    def forward(self, x):
        x = self.fc1(x)
        return x

我的第一个问题是,为什么即使我的__init__没有training_signals的任何正弦参数,为什么我也可以简单地运行下面的代码,并且好像training_signals传递给了{ {1}}方法。如何运作?

forward()

第二个问题是model = LinearRegression() training_signals = torch.rand(1000,20) model(training_signals) 在内部如何工作?是在调用self.apply(init_weights)方法之前执行的吗?

1 个答案:

答案 0 :(得分:3)

  

问题1:为什么即使我的__init__没有training_signals的任何位置参数,我也可以简单地在下面运行代码,看起来training_signals传递给了{{1} } 方法。如何运作?

首先,运行此行时将调用forward()

__init__

如您所见,您不传递任何参数,也不应该传递。 model = LinearRegression() 的签名与基类(运行__init__时调用的基类)的签名相同。如您所见heresuper(LinearRegression, self).__init__()的初始化签名就是nn.Module(就像您的签名一样)。

第二,def __init__(self)现在是一个对象。当您运行以下行时:

model

您实际上是在调用model(training_signals) 方法并传递__call__作为位置参数。如您所见here,除其他外,training_signals方法调用__call__方法:

forward

result = self.forward(*input, **kwargs) 的所有参数(位置参数和命名参数)传递给__call__

  

第二季度:forward在内部如何运作?是在调用forward方法之前执行的吗?

PyTorch是开源的,因此您只需转到源代码并进行检查。如您所见here,实现非常简单:

self.apply(init_weights)

引用函数的文档:它“ 递归地将def apply(self, fn): for module in self.children(): module.apply(fn) fn(self) return self 应用于fn.children() 的每个子模块”。根据实现,您还可以了解要求:

  • self必须是可通话的;
  • fn仅接收一个fn对象作为输入;