PyTorch-调用super的forward()方法

时间:2019-02-18 18:02:38

标签: python pytorch super

调用父级forward()的{​​{1}}方法的最合适方法是什么?例如,如果我将Module模块作为子类,则可以执行以下操作

nn.Linear

但是,docs说不要直接调用class LinearWithOtherStuff(nn.Linear): def forward(self, x): y = super(Linear, self).forward(x) z = do_other_stuff(y) return z 方法:

  

尽管需要在此函数中定义前向传递的方法,但应随后调用Module实例,而不是调用此实例,因为前者负责运行已注册的钩子,而后者则静默地忽略它们。

这使我认为forward()可能会导致某些意外错误。这是真的吗?还是我误解了继承?

1 个答案:

答案 0 :(得分:1)

TLDR;

即使带有钩子,也可以自由使用super().forward(...),甚至可以使用super()实例中注册的钩子。

说明

如上所述by this answer __call__在这里,因此将运行已注册的钩子(例如register_forward_hook)。

如果您继承并想重用基类的forward,例如这个:

import torch


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        return super(Child, self).forward(tensor) + 1


module = Child()
# Increment output by 1 so we should get `4`
module.register_forward_hook(lambda module, input, output: output + 1)
print(module(torch.tensor(1))) # and it is 4 indeed
print(module.forward(torch.tensor(1))) # here it is 3 still

如果您调用__call__方法就很好了,forward不会运行该钩子(因此您会得到3)。

您不太可能希望在register_hook的实例上使用super,但让我们考虑这样的示例:

def increment_by_one(module, input, output):
    return output + 1


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        # Increment by `1` from Parent
        super().register_forward_hook(increment_by_one)
        return super().forward(tensor) + 1


module = Child()
# Increment output by 1 so we should get `5` in total
module.register_forward_hook(increment_by_one)
print(module(torch.tensor(1)))  # and it is 5 indeed
print(module.forward(torch.tensor(1)))  # here is 3

使用super().forward(...)很好,即使钩子也可以正常工作(这是使用__call__而不是forward的主要思想)。

顺便说一句。调用super().__call__(...)会引发InifiniteRecursion错误。