调用父级forward()
的{{1}}方法的最合适方法是什么?例如,如果我将Module
模块作为子类,则可以执行以下操作
nn.Linear
但是,docs说不要直接调用class LinearWithOtherStuff(nn.Linear):
def forward(self, x):
y = super(Linear, self).forward(x)
z = do_other_stuff(y)
return z
方法:
尽管需要在此函数中定义前向传递的方法,但应随后调用Module实例,而不是调用此实例,因为前者负责运行已注册的钩子,而后者则静默地忽略它们。
这使我认为forward()
可能会导致某些意外错误。这是真的吗?还是我误解了继承?
答案 0 :(得分:1)
即使带有钩子,也可以自由使用super().forward(...)
,甚至可以使用super()
实例中注册的钩子。
如上所述by this answer __call__
在这里,因此将运行已注册的钩子(例如register_forward_hook
)。
如果您继承并想重用基类的forward
,例如这个:
import torch
class Parent(torch.nn.Module):
def forward(self, tensor):
return tensor + 1
class Child(Parent):
def forward(self, tensor):
return super(Child, self).forward(tensor) + 1
module = Child()
# Increment output by 1 so we should get `4`
module.register_forward_hook(lambda module, input, output: output + 1)
print(module(torch.tensor(1))) # and it is 4 indeed
print(module.forward(torch.tensor(1))) # here it is 3 still
如果您调用__call__
方法就很好了,forward
不会运行该钩子(因此您会得到3
)。
您不太可能希望在register_hook
的实例上使用super
,但让我们考虑这样的示例:
def increment_by_one(module, input, output):
return output + 1
class Parent(torch.nn.Module):
def forward(self, tensor):
return tensor + 1
class Child(Parent):
def forward(self, tensor):
# Increment by `1` from Parent
super().register_forward_hook(increment_by_one)
return super().forward(tensor) + 1
module = Child()
# Increment output by 1 so we should get `5` in total
module.register_forward_hook(increment_by_one)
print(module(torch.tensor(1))) # and it is 5 indeed
print(module.forward(torch.tensor(1))) # here is 3
使用super().forward(...)
很好,即使钩子也可以正常工作(这是使用__call__
而不是forward
的主要思想)。
顺便说一句。调用super().__call__(...)
会引发InifiniteRecursion
错误。